Exemplo n.º 1
0
def valid_npi(questions: list, pretrained_encoder_weights: str,
              pretrained_npi_weights: str):
    """Load pretrained weights and measure NPI accuracy on addition questions.

    Builds a ``StateEncoder`` and an ``NPI`` (``max_depth=20``,
    ``max_steps=10000``) and loads their state dicts from the given files.
    For every ``(addend, augend)`` pair in ``questions``, it re-initialises a
    single ``AdditionEnv`` via ``env.setup`` and runs the addition program
    (``pgid`` 2, no arguments) with ``npi.step``. An answer counts as correct
    when ``env.result == addend + augend``. Wrong answers are printed through
    the progress bar.

    Args:
        questions: sequence of ``(addend, augend)`` pairs.
        pretrained_encoder_weights: path to the saved ``StateEncoder`` state
            dict.
        pretrained_npi_weights: path to the saved ``NPI`` state dict.

    Returns:
        ``(correct, wrong)``: the number of correctly and incorrectly answered
        questions.
    """
    # map_location lets checkpoints saved on a different device (e.g. a GPU)
    # be loaded straight onto config.device.
    state_encoder = StateEncoder().to(config.device)
    state_encoder.load_state_dict(
        torch.load(pretrained_encoder_weights, map_location=config.device))
    npi = NPI(state_encoder, max_depth=20, max_steps=10000).to(config.device)
    npi.load_state_dict(
        torch.load(pretrained_npi_weights, map_location=config.device))
    npi.eval()

    env = AdditionEnv()
    add_program: dict = {'pgid': 2, 'args': []}
    wrong: int = 0
    correct: int = 0
    # A dict postfix is rendered immediately as "correct=0, wrong=0"; a plain
    # format string would be shown verbatim with its braces unfilled.
    loop = tqdm(questions, postfix={'correct': correct, 'wrong': wrong})
    for addend, augend in loop:
        npi.reset()
        with torch.no_grad():
            env.setup(addend, augend)
            # run npi algorithm
            npi.step(env, add_program['pgid'], add_program['args'])
        if env.result != (addend + augend):
            wrong += 1
            loop.write('{:>5} + {:>5} = {:>5}'.format(addend, augend,
                                                      env.result))
        else:
            correct += 1
        loop.set_postfix(correct=correct, wrong=wrong)
    return correct, wrong
Exemplo n.º 2
0
def _float_on_device(value):
    """Convert ``value`` to a float32 tensor and move it to ``config.device``."""
    return tensor(value).type(torch.FloatTensor).to(config.device)


def validate(npi: NPI, steplists: list, epochs: int = 100):
    """Evaluate ``npi`` on a list of recorded execution traces.

    Each entry of ``steplists`` is used for two measurements:

    * End-to-end accuracy: ``test_question`` is run on ``step['question']``.
      Its result is compared with ``np.sum(question)``.
    * Step-wise loss: every ``(env, (pgid, args), (pgid_out, args_out),
      term_out)`` tuple of ``step['trace']`` is fed through ``npi`` and scored
      with ``npi_criteria``.

    All model evaluation runs under ``torch.no_grad()``. ``npi`` is switched
    to eval mode for the duration of the call. Its previous train/eval mode
    is restored before returning, including when an exception is raised.

    Args:
        npi: the model under evaluation.
        steplists: dicts with ``'question'`` and ``'trace'`` keys.
        epochs: unused; kept for backward compatibility with existing callers.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the average of all
        per-step losses, or ``nan`` if no trace steps were seen. ``accuracy``
        is the fraction of questions answered correctly, or ``0.0`` if
        ``steplists`` is empty.
    """
    _, arg_depth = config.arg_shape
    valid_loss: list = []
    correct = 0
    was_training = npi.training
    npi.eval().to(config.device)
    try:
        with torch.no_grad():
            for step in steplists:
                question, trace = step['question'], step['trace']
                if test_question(question, npi) == np.sum(question):
                    correct += 1
                npi.reset()
                for env, (pgid, args), (pgid_out, args_out), term_out in trace:
                    # Per-term loss weights: termination, program id, then
                    # one per argument. 1e-10 effectively masks a term.
                    # NOTE(review): the pgid weight is gated on the *input*
                    # pgid (0 <= pgid < 6), not pgid_out — confirm intended.
                    # Args whose argmax is the last slot (arg_depth - 1) are
                    # masked; presumably that slot means "no argument".
                    weights = [1] + [1 if 0 <= pgid < 6 else 1e-10] + [
                        1e-10 if np.argmax(arg) == (arg_depth - 1) else 1
                        for arg in args_out
                    ]
                    # get environment observation
                    term_pred, pgid_pred, args_pred = npi(
                        _float_on_device(env).flatten(),
                        tensor(pgid).to(config.device),
                        _float_on_device(args).flatten())
                    total_loss = npi_criteria(
                        term_pred, _float_on_device(term_out),
                        pgid_pred, tensor(pgid_out).to(config.device),
                        args_pred, _float_on_device(args_out),
                        weights)
                    valid_loss.append(total_loss.item())
    finally:
        # Do not leave a model that is mid-training stuck in eval mode.
        npi.train(was_training)

    mean_loss = np.average(valid_loss) if valid_loss else float('nan')
    accuracy = correct / len(steplists) if steplists else 0.0
    return mean_loss, accuracy