Пример #1
0
def valid_npi(questions: list, pretrained_encoder_weights: str,
              pretrained_npi_weights: str) -> tuple:
    """Run a pretrained NPI over addition questions and count its answers.

    Args:
        questions: iterable of ``(addend, augend)`` pairs.
        pretrained_encoder_weights: path to the saved ``StateEncoder``
            state dict.
        pretrained_npi_weights: path to the saved ``NPI`` state dict.

    Returns:
        ``(correct, wrong)``: how many questions produced
        ``env.result == addend + augend`` and how many did not.
    """
    device = config.device

    # map_location lets a checkpoint saved on another device (e.g. a GPU)
    # be loaded on whatever device this run is configured for.
    state_encoder = StateEncoder().to(device)
    state_encoder.load_state_dict(
        torch.load(pretrained_encoder_weights, map_location=device))
    npi = NPI(state_encoder, max_depth=20, max_steps=10000).to(device)
    npi.load_state_dict(
        torch.load(pretrained_npi_weights, map_location=device))
    npi.eval()

    env = AdditionEnv()
    # Program id / arguments handed to npi.step for every question.
    add_program: dict = {'pgid': 2, 'args': []}
    wc: int = 0
    correct: int = 0

    # Seed the postfix with real counters so the bar never shows an
    # unformatted template before the first update.
    loop = tqdm(questions, postfix={'correct': correct, 'wrong': wc})
    for addend, augend in loop:
        npi.reset()
        with torch.no_grad():
            env.setup(addend, augend)
            # run npi algorithm
            npi.step(env, add_program['pgid'], add_program['args'])
        if env.result != (addend + augend):
            wc += 1
            # Report each wrong answer without breaking the progress bar.
            loop.write('{:>5} + {:>5} = {:>5}'.format(addend, augend,
                                                      env.result))
        else:
            correct += 1
        loop.set_postfix(correct=correct, wrong=wc)
    return correct, wc
Пример #2
0
def test_question(question: list, npi: NPI) -> int:
    """Run ``npi`` on a single addition question and return the env result.

    Args:
        question: a two-element sequence ``(addend, augend)``.
        npi: the NPI instance to run; it is reset before stepping.

    Returns:
        The value of ``env.result`` after ``npi.step`` has finished on a
        freshly created ``AdditionEnv``.
    """
    environment = AdditionEnv()
    first_operand, second_operand = question
    # Program id and (empty) argument list passed to npi.step.
    program_id, program_args = 2, []

    npi.reset()
    with torch.no_grad():
        environment.setup(first_operand, second_operand)
        # Execute the program against the environment.
        npi.step(environment, program_id, program_args)
        # Read the environment's answer while still inside no_grad.
        answer = environment.result
    return answer