Example no. 1
def encode_seq_seq_data(data : RawDataset,
                        context_tokenizer_type : Callable[[List[str], int], Tokenizer],
                        tactic_tokenizer_type : Callable[[List[str], int], Tokenizer],
                        num_keywords : int,
                        num_reserved_tokens : int) \
    -> Tuple[SequenceSequenceDataset, Tokenizer, Tokenizer]:
    # Build one tokenizer over the goal contexts and another over the tactics,
    # each keeping the top `num_keywords` keywords.
    context_tokenizer = make_keyword_tokenizer_topk(
        [context for prev_tactics, hyps, context, tactic in data],
        context_tokenizer_type, num_keywords, num_reserved_tokens)
    tactic_tokenizer = make_keyword_tokenizer_topk(
        [tactic for prev_tactics, hyps, context, tactic in data],
        tactic_tokenizer_type, num_keywords, num_reserved_tokens)
    # Encode every (context, tactic) pair as a pair of token lists.
    result = [(context_tokenizer.toTokenList(context),
               tactic_tokenizer.toTokenList(tactic))
              for prev_tactics, hyps, context, tactic in data]
    # Freeze both vocabularies so later calls cannot grow them.
    context_tokenizer.freezeTokenList()
    tactic_tokenizer.freezeTokenList()
    return result, context_tokenizer, tactic_tokenizer
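A minimal usage sketch, assuming RawDataset is simply a sequence of (prev_tactics, hypotheses, context, tactic) tuples and that the project's tokenizer table can be indexed as tokenizers["char"]; the datapoints, the "char" key, and the keyword counts below are illustrative, not taken from the example above.

# Hypothetical driver for encode_seq_seq_data; every name marked in the
# comments as illustrative is an assumption, not part of the original code.
raw_data = [
    ([], ["H : n = 0"], "n + 0 = 0", "rewrite H."),   # made-up datapoints
    ([], [], "True", "trivial."),
]
dataset, ctx_tok, tac_tok = encode_seq_seq_data(
    raw_data, tokenizers["char"], tokenizers["char"],   # "char" is illustrative
    num_keywords=100, num_reserved_tokens=2)
for context_tokens, tactic_tokens in dataset:
    print(context_tokens, "->", tactic_tokens)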
Example no. 2
def term_data(data : RawDataset,
              tokenizer_type : Callable[[List[str], int], Tokenizer],
              num_keywords : int,
              num_reserved_tokens : int) -> Tuple[TermDataset, Tokenizer]:
    # Flatten every hypothesis body (the text after the first ":", so types
    # that themselves contain colons stay intact) and every goal into one
    # list of term strings.
    term_strings = list(itertools.chain.from_iterable(
        [[hyp.split(":", 1)[1].strip() for hyp in hyps] + [goal]
         for prev_tactics, hyps, goal, tactic in data]))
    # Train a top-k keyword tokenizer on those terms and encode each term with it.
    tokenizer = make_keyword_tokenizer_topk(term_strings, tokenizer_type,
                                            num_keywords, num_reserved_tokens)
    return [tokenizer.toTokenList(term_string) for term_string in term_strings], \
        tokenizer
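A hedged usage sketch for term_data, under the same assumptions as before about the datapoint shape and the tokenizer table (the datapoint and the tokenizers["char"] key are illustrative):

# Hypothetical call; each datapoint is (prev_tactics, hyps, goal, tactic).
raw_data = [
    ([], ["H : n = 0", "IH : P n"], "P (S n)", "auto."),   # made-up datapoint
]
term_dataset, term_tokenizer = term_data(
    raw_data, tokenizers["char"], num_keywords=50, num_reserved_tokens=2)
# term_dataset holds one token list per hypothesis body and per goal.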
def main(args_list: List[str]) -> None:
    parser = argparse.ArgumentParser(description="Autoencoder for coq terms")
    add_std_args(parser)
    parser.add_argument("--gamma", default=.9, type=float)
    parser.add_argument("--epoch-step", default=5, type=int)
    parser.add_argument("--num-decoder-layers",
                        dest="num_decoder_layers",
                        default=3,
                        type=int)
    args = parser.parse_args(args_list)
    curtime = time.time()
    print("Loading data...", end="")
    sys.stdout.flush()
    # Read at most args.max_tuples datapoints from the scrape file.
    dataset = list(
        itertools.islice(read_text_data(args.scrape_file), args.max_tuples))

    print(" {:.2f}s".format(time.time() - curtime))
    curtime = time.time()
    print("Extracting terms...", end="")
    sys.stdout.flush()
    # Same extraction as term_data above: hypothesis bodies (after the first
    # ":") plus goals, flattened into one list of term strings.
    term_strings = list(
        chain.from_iterable([[hyp.split(":", 1)[1].strip()
                              for hyp in hyps] + [goal]
                             for prev_tactics, hyps, goal, tactic in dataset]))
    print(" {:.2f}s".format(time.time() - curtime))

    curtime = time.time()
    print("Building tokenizer...", end="")
    sys.stdout.flush()
    # Top-k keyword tokenizer over the extracted terms; 2 token ids are reserved.
    tokenizer = tk.make_keyword_tokenizer_topk(term_strings,
                                               tk.tokenizers[args.tokenizer],
                                               args.num_keywords, 2)
    print(" {:.2f}s".format(time.time() - curtime))
    curtime = time.time()
    print("Tokenizing {} strings...".format(len(term_strings)), end="")
    sys.stdout.flush()

    # Tokenize in parallel, handing each worker a chunk of 32768 strings.
    with multiprocessing.Pool(None) as pool:
        tokenized_data_chunks = pool.imap_unordered(
            functools.partial(use_tokenizer, tokenizer, args.max_length),
            chunks(term_strings, 32768))
        tokenized_data = list(chain.from_iterable(tokenized_data_chunks))

    print(" {:.2f}s".format(time.time() - curtime))
    # Train the term autoencoder; train() yields one checkpoint per epoch.
    checkpoints = train(tokenized_data, tokenizer.numTokens(), args.max_length,
                        args.hidden_size, args.learning_rate, args.epoch_step,
                        args.gamma, args.num_encoder_layers,
                        args.num_decoder_layers, args.num_epochs,
                        args.batch_size, args.print_every,
                        optimizers[args.optimizer])
    for epoch, (encoder_state, decoder_state,
                training_loss) in enumerate(checkpoints):
        state = {
            'epoch': epoch,
            'training-loss': training_loss,
            'tokenizer': tokenizer,
            'tokenizer-name': args.tokenizer,
            'optimizer': args.optimizer,
            'learning-rate': args.learning_rate,
            'encoder': encoder_state,
            'decoder': decoder_state,
            'num-encoder-layers': args.num_encoder_layers,
            'num-decoder-layers': args.num_decoder_layers,
            'max-length': args.max_length,
            'hidden-size': args.hidden_size,
            'num-keywords': args.num_keywords,
            'context-filter': args.context_filter,
        }
        # Overwrite the save file so it always holds the latest checkpoint.
        with open(args.save_file, 'wb') as f:
            print("=> Saving checkpoint at epoch {}".format(epoch))
            torch.save(state, f)
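Since main() overwrites args.save_file after every epoch, the file always contains the latest checkpoint, and it can be restored with torch.load. A minimal sketch, assuming the checkpoint dictionary has the same keys written above; the file name is illustrative:

import torch

# Load the most recent checkpoint written by main() above.
with open("autoencoder.dat", "rb") as f:   # illustrative path
    # Newer torch versions may require weights_only=False to unpickle the Tokenizer.
    checkpoint = torch.load(f)

print("epoch:", checkpoint['epoch'])
print("training loss:", checkpoint['training-loss'])
tokenizer = checkpoint['tokenizer']          # pickled Tokenizer instance
encoder_state = checkpoint['encoder']        # encoder weights returned by train()
decoder_state = checkpoint['decoder']        # decoder weights returned by train()
# Hyperparameters needed to rebuild the networks before loading the weights:
hidden_size = checkpoint['hidden-size']
num_encoder_layers = checkpoint['num-encoder-layers']
num_decoder_layers = checkpoint['num-decoder-layers']
max_length = checkpoint['max-length']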