def benchmark_experimental_vocab_lookup():
    """Benchmark single-token lookup speed of the existing ``Vocab`` against
    the experimental vocab, in eager mode and TorchScript (jit) mode.

    Builds both vocabs from the AG_NEWS training split and prints
    construction and lookup timings to stdout.

    NOTE(review): this definition is shadowed later in the file by another
    ``benchmark_experimental_vocab_lookup(vocab_file_path=None)`` — as
    written this version is dead code; it should be removed or renamed.
    """

    def _run_benchmark_lookup(tokens, vocab):
        # Time a plain __getitem__ lookup for every token; the result is
        # deliberately discarded — only the lookup cost matters.
        t0 = time.monotonic()
        for token in tokens:
            vocab[token]
        print("Lookup time:", time.monotonic() - t0)

    train, = AG_NEWS(data_select='train')
    vocab = train.get_vocab()
    # Re-materialize every token in the training split as a string.
    # ``token_id`` avoids shadowing the builtin ``id``.
    tokens = [vocab.itos[token_id]
              for (label, text) in train
              for token_id in text.tolist()]

    counter = Counter(tokens)
    # The experimental vocab is built from a frequency-sorted OrderedDict.
    sorted_by_freq_tuples = sorted(counter.items(),
                                   key=lambda x: x[1],
                                   reverse=True)
    ordered_dict = OrderedDict(sorted_by_freq_tuples)

    # existing Vocab construction
    print("Vocab")
    t0 = time.monotonic()
    v_existing = Vocab(counter)
    print("Construction time:", time.monotonic() - t0)

    # experimental Vocab construction
    print("Vocab Experimental")
    t0 = time.monotonic()
    v_experimental = VocabExperimental(ordered_dict)
    print("Construction time:", time.monotonic() - t0)
    jit_v_experimental = torch.jit.script(v_experimental)

    # existing Vocab eager lookup
    print("Vocab - Eager Mode")
    _run_benchmark_lookup(tokens, v_existing)

    # experimental Vocab eager lookup
    print("Vocab Experimental - Eager Mode")
    _run_benchmark_lookup(tokens, v_experimental)

    # experimental Vocab jit lookup
    print("Vocab Experimental - Jit Mode")
    _run_benchmark_lookup(tokens, jit_v_experimental)
def compare_legacy_and_experimental_batch_lookup():
    """Compare ``lookup_indices`` throughput of the legacy and experimental
    vocabs over batches of random token lines, for average line lengths from
    2 to 99 tokens.

    Prints the per-length speed-up and saves a speed-up plot to
    ``speedup.jpg`` in the current working directory.
    """
    num_tokens = 1000
    num_letters = 6
    num_lines = 100000
    # Random vocabulary of ``num_tokens`` strings, each ``num_letters`` long.
    vocab = [
        ''.join(random.sample(string.ascii_letters * num_letters, num_letters))
        for _ in range(num_tokens)
    ]
    counter = Counter(vocab)
    legacy_vocab = Vocab(counter)
    experimental_vocab = VocabExperimental(counter)
    speed_ups = []
    token_lengths = list(range(2, 100))  # average tokens per line to sweep
    for i in token_lengths:
        # ``num_lines`` lines of ``i`` distinct random tokens each.
        lines = [random.sample(vocab, i) for _ in range(num_lines)]

        start_time = timer()
        for text in lines:
            legacy_vocab.lookup_indices(text)
        legacy_time = timer() - start_time

        start_time = timer()
        for text in lines:
            experimental_vocab.lookup_indices(text)
        experimental_time = timer() - start_time

        speed_ups.append(legacy_time / experimental_time)
        print("speed-up={} for average length={}".format(
            legacy_time / experimental_time, i))
        # Release the batch before allocating the next one to cap peak RSS.
        del lines

    plt.close()
    fig, ax = plt.subplots(1, 1)
    ax.plot(token_lengths, speed_ups)
    ax.set_xlabel('Average Tokens per line')
    ax.set_ylabel('Speed-up')
    plt.savefig("speedup.jpg")
def benchmark_experimental_vocab_lookup(vocab_file_path=None):
    """Benchmark the existing ``Vocab`` against the experimental vocab for
    three access patterns — per-token ``__getitem__``, one big
    ``lookup_indices`` batch, and many per-line ``lookup_indices`` batches —
    in eager mode and TorchScript (jit) mode.

    Args:
        vocab_file_path: optional path to a one-token-per-line vocab file.
            When given, both vocabs are constructed from the file; otherwise
            they are built from the AG_NEWS training split.
    """

    def _run_benchmark_lookup(tokens, vocab):
        t0 = time.monotonic()
        # list of lists -> batched lookup_indices, one call per inner list
        if isinstance(tokens, list) and isinstance(tokens[0], list):
            for tokens_list in tokens:
                vocab.lookup_indices(tokens_list)
        # flat list -> single-token __getitem__ lookup
        elif isinstance(tokens, list):
            for token in tokens:
                vocab[token]
        else:
            raise RuntimeError("Received tokens of incorrect type {}.".format(
                type(tokens)))
        print("Lookup time:", time.monotonic() - t0)

    tokens = []
    tokens_lists = []

    train, = AG_NEWS(data_select='train')
    vocab = train.get_vocab()
    for (_, text) in train:
        # ``token_id`` avoids shadowing the builtin ``id``.
        cur_tokens = [vocab.itos[token_id] for token_id in text.tolist()]
        tokens_lists.append(cur_tokens)
        tokens += cur_tokens

    if vocab_file_path:
        print("Loading Vocab from file {}".format(vocab_file_path))

        def token_iterator(file_path):
            # ``with`` guarantees the handle is closed once the generator is
            # exhausted (the original leaked the file object).
            # NOTE(review): each yielded token keeps its trailing newline;
            # confirm build_vocab_from_iterator expects that.
            with open(file_path, 'r') as f:
                for token in f:
                    yield token

        # existing Vocab construction
        print("Vocab")
        t0 = time.monotonic()
        v_existing = build_vocab_from_iterator(token_iterator(vocab_file_path))
        print("Construction time:", time.monotonic() - t0)

        # experimental Vocab construction
        print("Vocab Experimental")
        t0 = time.monotonic()
        # close the file once construction finishes (the original leaked it)
        with open(vocab_file_path, 'r') as f:
            v_experimental = load_vocab_from_file(f)
        print("Construction time:", time.monotonic() - t0)
    else:
        print("Loading Vocab from AG News")
        counter = Counter(tokens)
        # The experimental vocab is built from a frequency-sorted OrderedDict.
        sorted_by_freq_tuples = sorted(counter.items(),
                                       key=lambda x: x[1],
                                       reverse=True)
        ordered_dict = OrderedDict(sorted_by_freq_tuples)

        # existing Vocab construction
        print("Vocab")
        t0 = time.monotonic()
        v_existing = Vocab(counter)
        print("Construction time:", time.monotonic() - t0)

        # experimental Vocab construction
        print("Vocab Experimental")
        t0 = time.monotonic()
        v_experimental = VocabExperimental(ordered_dict)
        print("Construction time:", time.monotonic() - t0)

    # Script once; the original scripted the same module a second time just
    # before the jit benchmarks, doing the compilation work twice.
    jit_v_experimental = torch.jit.script(v_experimental)

    # existing Vocab eager lookup
    print("Vocab - Eager Mode")
    _run_benchmark_lookup(tokens, v_existing)
    _run_benchmark_lookup([tokens], v_existing)
    _run_benchmark_lookup(tokens_lists, v_existing)

    # experimental Vocab eager lookup
    print("Vocab Experimental - Eager Mode")
    _run_benchmark_lookup(tokens, v_experimental)
    _run_benchmark_lookup([tokens], v_experimental)
    _run_benchmark_lookup(tokens_lists, v_experimental)

    # experimental Vocab jit lookup
    print("Vocab Experimental - Jit Mode")
    _run_benchmark_lookup(tokens, jit_v_experimental)
    _run_benchmark_lookup([tokens], jit_v_experimental)
    _run_benchmark_lookup(tokens_lists, jit_v_experimental)