def make_spice_style_train(lm, n_samples, max_len, filename):
    prepare_directory(filename, includes_filename=True)
    with open(filename, "w") as f:
        print(n_samples, len(lm.internal_alphabet), file=f)
        for _ in range(n_samples):
            s = lm.sample(cutoff=max_len)
            print(len(s), *s, file=f)

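# Usage sketch (illustrative only; `lm` stands for any language model exposing
# `internal_alphabet` and `sample(cutoff=...)`, and the filename is hypothetical).
# The written file follows the SPiCe convention: a header line "<n_samples> <alphabet size>",
# then one line per sample of the form "<len> <tok_1> ... <tok_len>".
#
#   make_spice_style_train(lm, n_samples=1000, max_len=50, filename="samples/train.spice.txt")
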
def save_rnn(full_rnn_folder, rnn, optimiser=None):
    was_cuda = rnn.rnn_module.using_cuda
    rnn.cpu(just_saving=True)  # safer if we want to open this on another computer later
    prepare_directory(full_rnn_folder, includes_filename=False)
    overwrite_file(rnn._to_dict(), dict_file(full_rnn_folder))
    torch.save(rnn.rnn_module.state_dict(), module_file(full_rnn_folder))
    if optimiser is not None:
        torch.save(optimiser.state_dict(), optimiser_file(full_rnn_folder))
    if was_cuda:  # put it back!
        rnn.cuda(returning_from_save=True)

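# Restoring the saved weights later is standard PyTorch (sketch only; rebuilding the RNN wrapper
# from the saved dict is left to this repo's own loading code, so the `rnn` object below is
# assumed to already exist):
#
#   rnn.rnn_module.load_state_dict(torch.load(module_file(full_rnn_folder)))
#   if optimiser is not None:
#       optimiser.load_state_dict(torch.load(optimiser_file(full_rnn_folder)))
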
def make_spice_preds(self, prefixes, filename=None):
    if filename is None:
        filename = "temporary_preds_" + str(time()) + ".txt"
    assert None not in prefixes
    prepare_directory(filename, includes_filename=True)
    with open(filename, "w") as f:
        for p in prefixes:
            state = self._state_from_sequence(p)
            preds = self._most_likely_token_from_state(
                state, k=len(self.internal_alphabet))  # just list them all in decreasing order
            preds = [str(t) if t != self.model.end_token else "-1" for t in preds]
            f.write(" ".join(preds) + "\n")
    return filename

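# Usage sketch (illustrative; `lm` is assumed to be an instance of the class defining this
# method, and the prefixes and filename are made up). Each output line lists the whole internal
# alphabet in decreasing likelihood for that prefix, with the end-of-sequence token written as "-1":
#
#   preds_file = lm.make_spice_preds([[0, 1, 1], [2]], filename="preds/test_preds.txt")
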
def do_ngram():
    print("~~~running ngram extraction~~~")
    print("making samples", end=" ... ")
    sample_start = process_time()
    samples = []
    length = 0
    lmrnn = LanguageModel(rnn)
    while length < args.ngram_total_sample_length:
        s = lmrnn.sample(cutoff=args.ngram_max_sample_length)
        samples.append(s)
        length += (len(s) + 1)  # ending the sequence is also a sample
    ngrams = {}
    ngrams_folder = rnn_folder + "/ngram"
    prepare_directory(ngrams_folder)
    sample_time = process_time() - sample_start
    print("done, that took:", clock_str(sample_start))
    print("making the actual ngrams", end=" ... ")
    with open(ngrams_folder + "/samples.txt", "w") as f:
        print(len(samples), len(rnn.internal_alphabet), file=f)
        for s in samples:
            print(len(s), *s, file=f)
    for n in args.ngram_ns:
        ngram_start = process_time()
        ngram = NGram(n, rnn.input_alphabet, samples)
        ngram.creation_info = {
            "extraction time": sample_time + process_time() - ngram_start,
            "size": len(ngram._state_probs_dist),
            "n": n,
            "total samples len (including EOS)": length,
            "num samples": len(samples),
            "samples cutoff len": args.ngram_max_sample_length
        }
        overwrite_file(ngram, ngrams_folder + "/" + str(n))
        ngrams[n] = ngram
    with open(ngrams_folder + "/creation_infos.txt", "w") as f:
        print("ngrams made from", len(samples), "samples, of total length", length,
              "(including EOSs)", file=f)
        for n in ngrams:
            print("===", n, "===\n", ngrams[n].creation_info, "\n\n", file=f)
    print("done, that took overall", clock_str(sample_start))
    return ngrams

def do_spectral():
    print("~~~running spectral extraction~~~")
    spectral_folder = rnn_folder + "/spectral_" + str(args.nPS)
    prepare_directory(spectral_folder)
    sample_start = process_time()
    print("making P,S with n_PS:", args.nPS, end="...")
    P, S = make_P_S(rnn, args.nPS, args.nPS, hard_stop=True,
                    max_attempts=args.spectral_max_sample_attempts,
                    max_sample_length=args.spectral_max_sample_length)
    print("done, that took:", clock_str(sample_start))
    sampling_time = process_time() - sample_start
    with open(spectral_folder + "/samples.txt", "w") as f:
        print("P (", len(P), ") :\n\n", file=f)
        for p in P:
            print(*p, file=f)
        print("S (", len(S), ") :\n\n", file=f)
        for s in S:
            print(*s, file=f)
    with open(spectral_folder + "/spectral_prints.txt", "w") as f:
        print("getting P,S took:", sampling_time, file=f)
        wfas, times_excl_sampling, hankel_time, svd_time, _ = spectral_reconstruct(
            rnn, P, S, args.k_list, print_file=f)
        print("making hankels took:", hankel_time, file=f)
        print("running svd took:", svd_time, file=f)
        generic_creation_info = {
            "|P|": len(P),
            "|S|": len(S),
            "rnn name": rnn.name,
            "hankel time": hankel_time,
            "svd time": svd_time
        }
        for wfa, t in zip(wfas, times_excl_sampling):
            # copy the dict so each wfa records its own "k" and "extraction time"
            wfa.creation_info = dict(generic_creation_info)
            wfa.creation_info["k"] = wfa.n
            wfa.creation_info["extraction time"] = t + sampling_time
            print("\n\n", wfa.n, "\n\n", wfa.creation_info, file=f)
            overwrite_file(wfa, spectral_folder + "/" + str(wfa.n))
    print("done, that took overall", clock_str(sample_start))
    return wfas

def do_lstar():
    print("~~~running weighted lstar extraction~~~")
    lstar_folder = rnn_folder + "/lstar"
    prepare_directory(lstar_folder)
    lstar_start = process_time()
    lstar_prints_filename = lstar_folder + "/extraction_prints.txt"
    print("progress prints will be in:", lstar_prints_filename)
    with open(lstar_prints_filename, "w") as f:
        lstar_pdfa, table, minimiser = learn(
            rnn,
            max_states=args.max_states,
            max_P=args.max_P,
            max_S=args.max_S,
            pdfas_path=lstar_folder,
            prints_path=f,
            atol=args.t_tol,
            interval_width=args.interval_width,
            n_cex_attempts=args.num_cex_attempts,
            max_counterexample_length=args.max_counterexample_length,
            expanding_time_limit=args.lstar_time_limit,
            s_separating_threshold=args.lstar_s_threshold,
            interesting_p_transition_threshold=args.lstar_p_threshold,
            progress_P_print_rate=args.progress_P_print_rate)
    lstar_pdfa.creation_info = {
        "extraction time": process_time() - lstar_start,
        "size": len(lstar_pdfa.transitions)
    }
    # record all the extraction hyperparams as well, though this will also catch
    # unrelated hyperparams such as the ngram ones
    lstar_pdfa.creation_info.update(vars(args))
    overwrite_file(lstar_pdfa, lstar_folder + "/pdfa")  # will end up in .gz
    with open(lstar_folder + "/extraction_info.txt", "w") as f:
        print(lstar_pdfa.creation_info, file=f)
    lstar_pdfa.draw_nicely(keep=True, filename=lstar_folder + "/pdfa")  # will end up in .img
    print("finished lstar extraction, that took:", clock_str(lstar_start))
    return lstar_pdfa

def train_rnn(rnn, train_set, validation_set, full_rnn_folder,
              iterations_per_learning_rate=500, learning_rates=None, batch_size=100,
              check_improvement_every=1, step_size_for_progress_checks=200,
              progress_seqs_at_a_time=100, track_train_loss=False, ignore_prev_best_losses=False):
    # iterations_per_learning_rate: number of iterations to spend on each learning rate, e.g. [10,20,30].
    # learning_rates: normally start at 0.001 and decay by about 0.7 between stages.
    def periodic_stats(check_improvement_counter):
        break_iter = False
        check_improvement_counter += 1
        if check_improvement_counter % check_improvement_every == 0:
            if not ti.validation_and_train_improving_overall(already_know_it_didnt_here=had_error):
                break_iter = True
                return check_improvement_counter, break_iter
            rnn.plot_stats_history(plots_path=full_rnn_folder + "/training_plots")
        return check_improvement_counter, break_iter

    def finish():
        rnn = ti.reload_and_get_best_rnn(with_optimiser=False)
        save_rnn(full_rnn_folder, rnn)
        ti.delete_metadata()
        rnn.total_train_time += (process_time() - start)  # add all the time we spent in here
        if not track_train_loss:  # at least compute the last one
            rnn.training_losses[-1] = rnn.detached_average_loss_on_group(
                train_set, step_size=step_size_for_progress_checks,
                seqs_at_a_time=progress_seqs_at_a_time)
        print("reached average training loss of:", rnn.training_losses[-1],
              file=training_prints_file, flush=True)
        print("and average validation loss of:", rnn.validation_losses[-1],
              file=training_prints_file, flush=True)
        print("overall time spent training, including those dropped to validation:",
              process_time() - start, file=training_prints_file, flush=True)
        rnn.plot_stats_history(plots_path=full_rnn_folder + "/training_plots")
        return rnn

    training_prints_filename = full_rnn_folder + "/training_prints.txt"
    if not train_set:  # empty train set
        with open(training_prints_filename, "a") as f:
            print("train set empty, doing nothing", file=f)
        return rnn
    iterations_per_learning_rate, learning_rates = _check_learning_rates_and_iterations(
        iterations_per_learning_rate, learning_rates)
    prepare_directory(full_rnn_folder, includes_filename=False)
    check_improvement_counter = 0
    with open(training_prints_filename, "a") as training_prints_file:
        print("training rnn:", rnn.informal_name, file=training_prints_file, flush=True)
        start = process_time()
        # print("current rnn train time is:", rnn.total_train_time)
        ti = TrainingInfo(rnn, validation_set, train_set, 0, training_prints_file,
                          full_rnn_folder, start, track_train_loss, ignore_prev_best_losses,
                          step_size_for_progress_checks, progress_seqs_at_a_time)
        representative_word = train_set[0]
        try:
            l_r_c = 0
            for learning_rate, iterations in zip(learning_rates, iterations_per_learning_rate):
                ti.trainer = torch.optim.Adam(rnn.rnn_module.parameters(), lr=learning_rate)
                l_r_c += 1
                ti.batches_since_validation_improvement = 0
                for i in range(iterations):
                    batches = make_shuffled_batches(train_set, batch_size)
                    ti.batches_before_validation_cutoff = int(len(batches) / check_improvement_every)
                    batch_c = 0
                    for b in batches:
                        batch_c += 1
                        print("learning rate", l_r_c, "of", len(learning_rates),
                              "(", clean_val(learning_rate, 6), "), iteration", i + 1, "of", iterations,
                              ", batch", batch_c, "of", len(batches),
                              file=training_prints_file, flush=True, end="")
                        had_error = False
                        try:
                            batch_start = process_time()
                            train_specific_batch(rnn, b, learning_rate, ti.trainer)
                            print(" finished, that took:", clean_val(process_time() - batch_start),
                                  file=training_prints_file, flush=True)
                        except RuntimeError as e:
                            print("\ntraining with learning rate", learning_rate, "hit error:", str(e),
                                  file=training_prints_file, flush=True)
                            # something went wrong, get the best rnn back
                            rnn = ti.reload_and_get_best_rnn(learning_rate=learning_rate)
                            had_error = True
                        rnn._update_stats_except_train_and_validation(
                            representative_word)  # weird partial updates here but w/e
                        check_improvement_counter, break_iter = periodic_stats(check_improvement_counter)
                        if break_iter:
                            break
                    if break_iter:
                        break
        except KeyboardInterrupt:
            print("stopped by user - losses may be different than those last recorded",
                  file=training_prints_file, flush=True)
            save_rnn(full_rnn_folder + "/last_before_interrupt", rnn)
        if not (check_improvement_counter - 1) % check_improvement_every == 0:
            # i.e. didn't literally just check stats
            periodic_stats(0)
        return finish()

args.num_cex_attempts = 20
args.max_P = 50
args.max_S = 20
args.lstar_time_limit = 20
args.ngram_total_sample_length = 1e3
args.ndcg_num_samples = 100
args.wer_num_samples = 100
args.wer_max_len = 10

args.k_list = []
for t in args.k_ranges:
    args.k_list += list(range(*t))

uhl = {1: uhl1(), 2: uhl2(), 3: uhl3()}

folder = "results"
prepare_directory(folder)


def make_spice_style_train(lm, n_samples, max_len, filename):
    prepare_directory(filename, includes_filename=True)
    with open(filename, "w") as f:
        print(n_samples, len(lm.internal_alphabet), file=f)
        for _ in range(n_samples):
            s = lm.sample(cutoff=max_len)
            print(len(s), *s, file=f)


def read_spice_style_train_data(filename):
    print("loading from file:", filename)
    if not os.path.exists(filename):
        return None, None