def encode_seq_seq_data(data : RawDataset,
                        context_tokenizer_type : Callable[[List[str], int], Tokenizer],
                        tactic_tokenizer_type : Callable[[List[str], int], Tokenizer],
                        num_keywords : int, num_reserved_tokens : int) \
        -> Tuple[SequenceSequenceDataset, Tokenizer, Tokenizer]:
    # Build separate keyword tokenizers for goal contexts and tactics, then
    # encode each (context, tactic) pair as a pair of token lists.
    context_tokenizer = make_keyword_tokenizer_topk(
        [context for prev_tactics, hyps, context, tactic in data],
        context_tokenizer_type, num_keywords, num_reserved_tokens)
    tactic_tokenizer = make_keyword_tokenizer_topk(
        [tactic for prev_tactics, hyps, context, tactic in data],
        tactic_tokenizer_type, num_keywords, num_reserved_tokens)
    result = [(context_tokenizer.toTokenList(context),
               tactic_tokenizer.toTokenList(tactic))
              for prev_tactics, hyps, context, tactic in data]
    # Freeze both vocabularies so later encoding can't grow them.
    context_tokenizer.freezeTokenList()
    tactic_tokenizer.freezeTokenList()
    return result, context_tokenizer, tactic_tokenizer
def term_data(data : RawDataset,
              tokenizer_type : Callable[[List[str], int], Tokenizer],
              num_keywords : int,
              num_reserved_tokens : int) -> Tuple[TermDataset, Tokenizer]:
    # Collect every term string in the dataset: each hypothesis type
    # (the part after "name :") plus the goal itself.
    term_strings = list(itertools.chain.from_iterable(
        [[hyp.split(":")[1].strip() for hyp in hyps] + [goal]
         for prev_tactics, hyps, goal, tactic in data]))
    tokenizer = make_keyword_tokenizer_topk(term_strings, tokenizer_type,
                                            num_keywords, num_reserved_tokens)
    return [tokenizer.toTokenList(term_string)
            for term_string in term_strings], \
        tokenizer
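# Usage sketch (not from the original source): dataset rows are
# (prev_tactics, hyps, goal, tactic) tuples, with each hypothesis written as a
# "name : type" string. The tokenizer registry key "chars" below is
# hypothetical; substitute whatever names tk.tokenizers actually defines.
#
#   raw_data = [([], ["H : n = 0"], "n + 0 = 0", "reflexivity.")]
#   terms, term_tok = term_data(raw_data, tk.tokenizers["chars"], 100, 2)
#   pairs, ctx_tok, tac_tok = encode_seq_seq_data(
#       raw_data, tk.tokenizers["chars"], tk.tokenizers["chars"], 100, 2)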
def main(args_list: List[str]) -> None:
    parser = argparse.ArgumentParser(description="Autoencoder for coq terms")
    add_std_args(parser)
    parser.add_argument("--gamma", default=.9, type=float)
    parser.add_argument("--epoch-step", default=5, type=int)
    parser.add_argument("--num-decoder-layers", dest="num_decoder_layers",
                        default=3, type=int)
    args = parser.parse_args(args_list)

    # Load the scraped (prev_tactics, hyps, goal, tactic) tuples.
    curtime = time.time()
    print("Loading data...", end="")
    sys.stdout.flush()
    dataset = list(itertools.islice(read_text_data(args.scrape_file),
                                    args.max_tuples))
    print(" {:.2f}s".format(time.time() - curtime))

    # Extract the raw term strings: hypothesis types and goals.
    curtime = time.time()
    print("Extracting terms...", end="")
    sys.stdout.flush()
    term_strings = list(chain.from_iterable(
        [[hyp.split(":")[1].strip() for hyp in hyps] + [goal]
         for prev_tactics, hyps, goal, tactic in dataset]))
    print(" {:.2f}s".format(time.time() - curtime))

    curtime = time.time()
    print("Building tokenizer...", end="")
    sys.stdout.flush()
    tokenizer = tk.make_keyword_tokenizer_topk(term_strings,
                                               tk.tokenizers[args.tokenizer],
                                               args.num_keywords, 2)
    print(" {:.2f}s".format(time.time() - curtime))

    # Tokenize the terms in parallel, 32768 strings per worker chunk.
    curtime = time.time()
    print("Tokenizing {} strings...".format(len(term_strings)), end="")
    sys.stdout.flush()
    with multiprocessing.Pool(None) as pool:
        tokenized_data_chunks = pool.imap_unordered(
            functools.partial(use_tokenizer, tokenizer, args.max_length),
            chunks(term_strings, 32768))
        tokenized_data = list(chain.from_iterable(tokenized_data_chunks))
    print(" {:.2f}s".format(time.time() - curtime))

    # Train the autoencoder, yielding one checkpoint per epoch.
    checkpoints = train(tokenized_data, tokenizer.numTokens(), args.max_length,
                        args.hidden_size, args.learning_rate, args.epoch_step,
                        args.gamma, args.num_encoder_layers,
                        args.num_decoder_layers, args.num_epochs,
                        args.batch_size, args.print_every,
                        optimizers[args.optimizer])

    # Save each checkpoint; the same file is reopened every epoch, so only
    # the latest checkpoint persists on disk.
    for epoch, (encoder_state, decoder_state, training_loss) \
            in enumerate(checkpoints):
        state = {'epoch': epoch,
                 'training-loss': training_loss,
                 'tokenizer': tokenizer,
                 'tokenizer-name': args.tokenizer,
                 'optimizer': args.optimizer,
                 'learning-rate': args.learning_rate,
                 'encoder': encoder_state,
                 'decoder': decoder_state,
                 'num-encoder-layers': args.num_encoder_layers,
                 'num-decoder-layers': args.num_decoder_layers,
                 'max-length': args.max_length,
                 'hidden-size': args.hidden_size,
                 'num-keywords': args.num_keywords,
                 'context-filter': args.context_filter,
                 }
        with open(args.save_file, 'wb') as f:
            print("=> Saving checkpoint at epoch {}".format(epoch))
            torch.save(state, f)
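# A conventional script entry point, sketched as an assumption: only add this
# if the module is meant to run standalone and no such guard already exists
# elsewhere in the file.
#
#   if __name__ == "__main__":
#       main(sys.argv[1:])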