def save_rnn(full_rnn_folder, rnn, optimiser=None):
    was_cuda = rnn.rnn_module.using_cuda
    rnn.cpu(just_saving=True)  # move to CPU so the saved model can be opened on another machine later
    prepare_directory(full_rnn_folder, includes_filename=False)
    overwrite_file(rnn._to_dict(), dict_file(full_rnn_folder))
    torch.save(rnn.rnn_module.state_dict(), module_file(full_rnn_folder))
    if optimiser is not None:
        torch.save(optimiser.state_dict(), optimiser_file(full_rnn_folder))
    if was_cuda:  # put it back!!
        rnn.cuda(returning_from_save=True)
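

# A minimal reload sketch for the module weights written by save_rnn (an assumption, not
# part of the original code): module_file() is the same project-specific path helper used
# above, and `rnn` must already be constructed with the matching architecture.
def load_rnn_module_weights(full_rnn_folder, rnn):
    state = torch.load(module_file(full_rnn_folder), map_location="cpu")
    rnn.rnn_module.load_state_dict(state)
    return rnn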


def make_spice_preds(self, prefixes, filename=None):
    if filename is None:
        filename = "temporary_preds_" + str(time()) + ".txt"
    assert None not in prefixes
    prepare_directory(filename, includes_filename=True)
    with open(filename, "w") as f:
        for p in prefixes:
            state = self._state_from_sequence(p)
            preds = self._most_likely_token_from_state(
                state, k=len(self.internal_alphabet)
            )  # just list them all in decreasing order
            preds = [
                str(t) if t != self.model.end_token else "-1" for t in preds
            ]
            f.write(" ".join(preds) + "\n")
    return filename
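

# A small illustration (not part of the original code) of consuming the file written by
# make_spice_preds: each line lists the whole alphabet in decreasing likelihood, with the
# end-of-sequence token written as "-1", so the first entry on a line is the top prediction.
def read_top_predictions(filename):
    with open(filename) as f:
        return [line.split()[0] for line in f if line.strip()]
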
def do_ngram():
    print("~~~running ngram extraction~~~")
    print("making samples", end=" ... ")
    sample_start = process_time()
    samples = []
    length = 0
    lmrnn = LanguageModel(rnn)
    while length < args.ngram_total_sample_length:
        s = lmrnn.sample(cutoff=args.ngram_max_sample_length)
        samples.append(s)
        length += (len(s) + 1)  # predicting the end of the sequence also counts as one sample
    ngrams = {}
    ngrams_folder = rnn_folder + "/ngram"
    prepare_directory(ngrams_folder)
    sample_time = process_time() - sample_start
    print("done, that took:", clock_str(sample_start))
    print("making the actual ngrams", end=" ... ")
    with open(ngrams_folder + "/samples.txt", "w") as f:
        print(len(samples), len(rnn.internal_alphabet), file=f)
        for s in samples:
            print(len(s), *s, file=f)
    for n in args.ngram_ns:
        ngram_start = process_time()
        ngram = NGram(n, rnn.input_alphabet, samples)
        ngram.creation_info = {
            "extraction time": sample_time + process_time() - ngram_start,
            "size": len(ngram._state_probs_dist),
            "n": n,
            "total samples len (including EOS)": length,
            "num samples": len(samples),
            "samples cutoff len": args.ngram_max_sample_length
        }
        overwrite_file(ngram, ngrams_folder + "/" + str(n))
        ngrams[n] = ngram
    with open(ngrams_folder + "/creation_infos.txt", "w") as f:
        print("ngrams made from",
              len(samples),
              "samples, of total length",
              length,
              "(including EOSs)",
              file=f)
        for n in ngrams:
            print("===", n, "===\n", ngrams[n].creation_info, "\n\n", file=f)
    print("done, that took overall", clock_str(sample_start))
    return ngrams
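

# A generic, self-contained sketch of what an n-gram estimate over such samples looks like
# (standard n-gram counting, NOT the repository's NGram class): next-token probabilities
# conditioned on the last n-1 tokens, with an explicit end-of-sequence symbol. Short
# histories at the start of a sequence are used as-is rather than padded.
from collections import Counter, defaultdict

def ngram_counts_sketch(samples, n, eos="<EOS>"):
    counts = defaultdict(Counter)
    for s in samples:
        history = tuple()
        for tok in list(s) + [eos]:
            counts[history][tok] += 1
            history = (history + (tok,))[-(n - 1):] if n > 1 else tuple()
    # normalise each history's counts into a next-token probability distribution
    return {h: {t: c / sum(ctr.values()) for t, c in ctr.items()}
            for h, ctr in counts.items()}
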
def do_spectral():
    print("~~~running spectral extraction~~~")
    spectral_folder = rnn_folder + "/spectral_" + str(args.nPS)
    prepare_directory(spectral_folder)
    sample_start = process_time()  # time the P,S sampling before it starts
    print("making P,S with n_PS:", args.nPS, end="...")
    P, S = make_P_S(rnn, args.nPS, args.nPS, hard_stop=True,
                    max_attempts=args.spectral_max_sample_attempts,
                    max_sample_length=args.spectral_max_sample_length)
    sampling_time = process_time() - sample_start
    print("done, that took:", clock_str(sample_start))
    with open(spectral_folder + "/samples.txt", "w") as f:
        print("P (", len(P), ") :\n\n", file=f)
        for p in P:
            print(*p, file=f)
        print("S (", len(S), ") :\n\n", file=f)
        for s in S:
            print(*s, file=f)
    with open(spectral_folder + "/spectral_prints.txt", "w") as f:
        print("getting P,S took:", sampling_time, file=f)
        wfas, times_excl_sampling, hankel_time, svd_time, _ = spectral_reconstruct(
            rnn, P, S, args.k_list, print_file=f)
        print("making hankels took:", hankel_time, file=f)
        print("running svd took:", svd_time, file=f)
        generic_creation_info = {
            "|P|": len(P),
            "|S|": len(S),
            "rnn name": rnn.name,
            "hankel time": hankel_time,
            "svd time": svd_time
        }  # ,"k":wfa.n} # "extraction time":total_time+PStime
        for wfa, t in zip(wfas, times_excl_sampling):
            wfa.creation_info = dict(generic_creation_info)  # copy, so each WFA keeps its own info dict
            wfa.creation_info["k"] = wfa.n
            wfa.creation_info["extraction time"] = t + sampling_time
            print("\n\n", wfa.n, "\n\n", wfa.creation_info, file=f)
            overwrite_file(wfa, spectral_folder + "/" + str(wfa.n))
    print("done, that took overall", clock_str(sample_start))
    return wfas
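

# A generic spectral-reconstruction sketch (the standard Hankel-plus-truncated-SVD recipe,
# NOT the repository's spectral_reconstruct): given a scoring function f over token tuples,
# prefix/suffix lists P and S (both containing the empty tuple ()), an alphabet and a target
# rank k, build the Hankel blocks and read off a rank-k weighted automaton. Assumes numpy.
import numpy as np

def spectral_wfa_sketch(f, P, S, alphabet, k):
    H = np.array([[f(p + s) for s in S] for p in P])   # H[i, j] = f(P[i] + S[j])
    Ha = {a: np.array([[f(p + (a,) + s) for s in S] for p in P]) for a in alphabet}
    U, D, Vt = np.linalg.svd(H)
    U, D, Vt = U[:, :k], D[:k], Vt[:k, :]              # rank-k truncation
    pinv = np.diag(1.0 / D) @ U.T                      # pseudo-inverse of the left factor U @ diag(D)
    A = {a: pinv @ Ha[a] @ Vt.T for a in alphabet}     # one transition matrix per symbol
    alpha0 = Vt @ H[P.index(()), :]                    # initial weights: empty-prefix row, projected
    alpha_inf = pinv @ H[:, S.index(())]               # final weights: empty-suffix column, projected

    def score(seq):  # f(x1..xt) ~ alpha0 . A[x1] ... A[xt] . alpha_inf
        v = alpha_inf
        for a in reversed(seq):
            v = A[a] @ v
        return float(alpha0 @ v)

    return alpha0, A, alpha_inf, score
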
def do_lstar():
    print("~~~running weighted lstar extraction~~~")
    lstar_folder = rnn_folder + "/lstar"
    prepare_directory(lstar_folder)
    lstar_start = process_time()
    lstar_prints_filename = lstar_folder + "/extraction_prints.txt"
    print("progress prints will be in:", lstar_prints_filename)
    with open(lstar_prints_filename, "w") as f:
        lstar_pdfa, table, minimiser = learn(
            rnn,
            max_states=args.max_states,
            max_P=args.max_P,
            max_S=args.max_S,
            pdfas_path=lstar_folder,
            prints_path=f,
            atol=args.t_tol,
            interval_width=args.interval_width,
            n_cex_attempts=args.num_cex_attempts,
            max_counterexample_length=args.max_counterexample_length,
            expanding_time_limit=args.lstar_time_limit,
            s_separating_threshold=args.lstar_s_threshold,
            interesting_p_transition_threshold=args.lstar_p_threshold,
            progress_P_print_rate=args.progress_P_print_rate)
    lstar_pdfa.creation_info = {
        "extraction time": process_time() - lstar_start,
        "size": len(lstar_pdfa.transitions)
    }
    lstar_pdfa.creation_info.update(
        vars(args)
    )  # record the extraction hyperparameters as well, though this also picks up unrelated ones (e.g. the ngram settings)
    overwrite_file(lstar_pdfa, lstar_folder + "/pdfa")  # will end up in .gz
    with open(lstar_folder + "/extraction_info.txt", "w") as f:
        print(lstar_pdfa.creation_info, file=f)
    lstar_pdfa.draw_nicely(keep=True, filename=lstar_folder +
                           "/pdfa")  # will end up in .img
    print("finished lstar extraction, that took:", clock_str(lstar_start))
    return lstar_pdfa
def train_rnn(rnn, train_set, validation_set, full_rnn_folder,
              iterations_per_learning_rate=500, learning_rates=None,
              batch_size=100, check_improvement_every=1,
              step_size_for_progress_checks=200, progress_seqs_at_a_time=100,
              track_train_loss=False, ignore_prev_best_losses=False):
    # iterations_per_learning_rate: how many iterations to spend on each learning rate, e.g. [10, 20, 30].
    # learning_rates: typically start around 0.001 and decay by a factor of roughly 0.7 each step.
    def periodic_stats(check_improvement_counter):
        break_iter = False
        check_improvement_counter += 1
        if check_improvement_counter % check_improvement_every == 0:
            if not ti.validation_and_train_improving_overall(
                    already_know_it_didnt_here=had_error):
                break_iter = True
                return check_improvement_counter, break_iter
            rnn.plot_stats_history(plots_path=full_rnn_folder +
                                   "/training_plots")
        return check_improvement_counter, break_iter

    def finish():
        rnn = ti.reload_and_get_best_rnn(with_optimiser=False)
        save_rnn(full_rnn_folder, rnn)
        ti.delete_metadata()
        rnn.total_train_time += (process_time() - start
                                 )  # add all the time we wasted in here
        if not track_train_loss:  # at least compute the last one
            rnn.training_losses[-1] = rnn.detached_average_loss_on_group(
                train_set,
                step_size=step_size_for_progress_checks,
                seqs_at_a_time=progress_seqs_at_a_time)
        print("reached average training loss of:",
              rnn.training_losses[-1],
              file=training_prints_file,
              flush=True)
        print("and average validation loss of:",
              rnn.validation_losses[-1],
              file=training_prints_file,
              flush=True)
        print(
            "overall time spent training, including those dropped to validation:",
            process_time() - start,
            file=training_prints_file,
            flush=True)
        rnn.plot_stats_history(plots_path=full_rnn_folder + "/training_plots")
        return rnn

    training_prints_filename = full_rnn_folder + "/training_prints.txt"
    if not train_set:  # empty train set
        with open(training_prints_filename, "a") as f:
            print("train set empty, doing nothing", file=f)
        return rnn
    iterations_per_learning_rate, learning_rates = _check_learning_rates_and_iterations(
        iterations_per_learning_rate, learning_rates)

    prepare_directory(full_rnn_folder, includes_filename=False)
    check_improvement_counter = 0
    with open(training_prints_filename, "a") as training_prints_file:
        print("training rnn:",
              rnn.informal_name,
              file=training_prints_file,
              flush=True)
        start = process_time()
        # print("current rnn train time is:",rnn.total_train_time)

        ti = TrainingInfo(rnn, validation_set, train_set, 0, training_prints_file,
                          full_rnn_folder, start, track_train_loss, ignore_prev_best_losses,
                          step_size_for_progress_checks, progress_seqs_at_a_time)
        representative_word = train_set[0]
        try:
            l_r_c = 0
            for learning_rate, iterations in zip(learning_rates,
                                                 iterations_per_learning_rate):
                ti.trainer = torch.optim.Adam(rnn.rnn_module.parameters(),
                                              lr=learning_rate)
                l_r_c += 1
                ti.batches_since_validation_improvement = 0
                for i in range(iterations):
                    batches = make_shuffled_batches(train_set, batch_size)
                    ti.batches_before_validation_cutoff = int(
                        len(batches) / check_improvement_every)
                    batch_c = 0
                    for b in batches:
                        batch_c += 1
                        print("learning rate",l_r_c,"of",len(learning_rates),"(",clean_val(learning_rate,6),\
                         "), iteration",i+1,"of",iterations,", batch",batch_c,"of",len(batches),file=training_prints_file,flush=True,end="")
                        had_error = False
                        try:
                            batch_start = process_time()
                            train_specific_batch(rnn, b, learning_rate,
                                                 ti.trainer)
                            print(" finished, that took:",
                                  clean_val(process_time() - batch_start),
                                  file=training_prints_file,
                                  flush=True)
                        except RuntimeError as e:
                            print("\ntraining with learning rate",
                                  learning_rate,
                                  "hit error:",
                                  str(e),
                                  file=training_prints_file,
                                  flush=True)
                            rnn = ti.reload_and_get_best_rnn(
                                learning_rate=learning_rate
                            )  # something went wrong, get the best rnn back
                            had_error = True
                        rnn._update_stats_except_train_and_validation(
                            representative_word
                        )  # weird partial updates here but w/e
                        check_improvement_counter, break_iter = periodic_stats(
                            check_improvement_counter)
                        if break_iter:
                            break
                    if break_iter:
                        break

        except KeyboardInterrupt:
            print(
                "stopped by user - losses may be different than those last recorded",
                file=training_prints_file,
                flush=True)
            save_rnn(full_rnn_folder + "/last_before_interrupt", rnn)
        if (check_improvement_counter - 1) % check_improvement_every != 0:
            # i.e. we didn't literally just check the stats
            periodic_stats(0)

        return finish()
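

# The training loop above uses a make_shuffled_batches helper that is not shown in this
# snippet. A minimal sketch of what such a helper plausibly does (an assumption, not the
# original implementation): shuffle the training sequences and cut them into fixed-size batches.
import random

def make_shuffled_batches_sketch(train_set, batch_size):
    seqs = list(train_set)
    random.shuffle(seqs)
    return [seqs[i:i + batch_size] for i in range(0, len(seqs), batch_size)]
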
args.num_cex_attempts = 20
args.max_P = 50
args.max_S = 20
args.lstar_time_limit = 20
args.ngram_total_sample_length = 1e3
args.ndcg_num_samples = 100
args.wer_num_samples = 100
args.wer_max_len = 10

args.k_list = []
for t in args.k_ranges:
    args.k_list += list(range(*t))
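# For example, with args.k_ranges = [(2, 5), (10, 12)] this expands to
# args.k_list = [2, 3, 4, 10, 11] (each tuple is splatted into range()).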

uhl = {1: uhl1(), 2: uhl2(), 3: uhl3()}
folder = "results"
prepare_directory(folder)


def make_spice_style_train(lm, n_samples, max_len, filename):
    prepare_directory(filename, includes_filename=True)
    with open(filename, "w") as f:
        print(n_samples, len(lm.internal_alphabet), file=f)
        for _ in range(n_samples):
            s = lm.sample(cutoff=max_len)
            print(len(s), *s, file=f)
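
# For illustration (an assumption about contents, not a file from the original project), a
# file written by make_spice_style_train with n_samples=3 and a 5-symbol internal alphabet
# could look like:
#
#   3 5
#   2 0 4
#   0
#   3 1 1 2
#
# i.e. a header "n_samples alphabet_size", then one line per sample: its length followed by
# its tokens (presumably read back by read_spice_style_train_data below).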


def read_spice_style_train_data(filename):
    print("loading from file:", filename)
    if not os.path.exists(filename):
        return None, None