def valid_npi(questions: list, pretrained_encoder_weights: str, pretrained_npi_weights: str):
    """Evaluate pretrained NPI weights on a list of addition questions.

    Builds a ``StateEncoder`` and an ``NPI`` (max_depth=20, max_steps=10000),
    loads both from the given weight files, then for every ``(addend, augend)``
    pair resets the model, sets up an ``AdditionEnv`` and runs program id 2
    with no arguments.  The answer is judged by comparing ``env.result`` with
    ``addend + augend``; wrong answers are printed through the progress bar.

    Args:
        questions: iterable of ``(addend, augend)`` pairs.
        pretrained_encoder_weights: path to a saved ``StateEncoder`` state dict.
        pretrained_npi_weights: path to a saved ``NPI`` state dict.

    Returns:
        ``(correct, wrong)``: number of correctly and incorrectly answered
        questions.
    """
    # map_location remaps stored tensors onto config.device, so a checkpoint
    # saved on a GPU can still be loaded on a CPU-only host (and vice versa).
    state_encoder = StateEncoder().to(config.device)
    state_encoder.load_state_dict(
        torch.load(pretrained_encoder_weights, map_location=config.device))
    npi = NPI(state_encoder, max_depth=20, max_steps=10000).to(config.device)
    npi.load_state_dict(
        torch.load(pretrained_npi_weights, map_location=config.device))
    npi.eval()  # already on config.device; only the mode needs setting

    env = AdditionEnv()
    add_program: dict = {'pgid': 2, 'args': []}
    wc: int = 0
    correct: int = 0
    # A dict postfix is rendered right away; a plain format string would be
    # displayed verbatim (braces included) until the first set_postfix call.
    loop = tqdm(questions, postfix={'correct': correct, 'wrong': wc})
    for addend, augend in loop:
        npi.reset()
        with torch.no_grad():
            env.setup(addend, augend)
            # run npi algorithm
            npi.step(env, add_program['pgid'], add_program['args'])
        if env.result != (addend + augend):
            wc += 1
            loop.write('{:>5} + {:>5} = {:>5}'.format(addend, augend, env.result))
        else:
            correct += 1
        # Refresh the counters after every question, right or wrong.
        loop.set_postfix(correct=correct, wrong=wc)
    return correct, wc
def validate(npi: NPI, steplists: list, epochs: int = 100):
    """Compute validation loss and end-to-end accuracy of an NPI model.

    For every step in ``steplists``:

    * the question is solved end to end with ``test_question`` and compared
      against ``np.sum(question)`` to count correct answers;
    * every trace entry is replayed through ``npi`` without gradients and
      scored with ``npi_criteria``; each scalar loss is collected.

    The model is switched to eval mode for the duration of the call.  Its
    previous train/eval mode is restored on exit, so calling this from inside
    a training loop does not leave the model in eval mode.

    Args:
        npi: the model to evaluate.
        steplists: sequence of dicts with keys ``'question'`` and ``'trace'``.
            Each trace entry unpacks as
            ``env, (pgid, args), (pgid_out, args_out), term_out``.
        epochs: unused; kept for backward compatibility with existing callers.

    Returns:
        ``(mean_loss, accuracy)``: the average of all per-entry losses and the
        fraction of questions answered correctly.  Both are ``nan`` when there
        is nothing to average (empty ``steplists`` / no trace entries).
    """
    _, arg_depth = config.arg_shape
    valid_loss: list = []
    correct = wc = 0

    was_training = npi.training
    npi.eval().to(config.device)
    try:
        for step in steplists:
            question, trace = step['question'], step['trace']
            # Pure inference: no autograd graph is needed for the rollout.
            with torch.no_grad():
                res = test_question(question, npi)
            if res == np.sum(question):
                correct += 1
            else:
                wc += 1

            npi.reset()
            with torch.no_grad():
                for env, (pgid, args), (pgid_out, args_out), term_out in trace:
                    # Loss weights: termination always counts; the program-id
                    # term counts only for pgid in [0, 6); an argument row is
                    # effectively ignored when its argmax is the last slot
                    # (presumably the "no argument" encoding -- TODO confirm).
                    weights = [1] + [1 if 0 <= pgid < 6 else 1e-10] + [
                        1e-10 if np.argmax(arg) == (arg_depth - 1) else 1
                        for arg in args_out
                    ]
                    # get environment observation
                    term_pred, pgid_pred, args_pred = npi(
                        tensor(env).flatten().type(torch.FloatTensor).to(
                            config.device),
                        tensor(pgid).to(config.device),
                        tensor(args).flatten().type(torch.FloatTensor).to(
                            config.device))
                    total_loss = npi_criteria(
                        term_pred,
                        tensor(term_out).type(torch.FloatTensor).to(config.device),
                        pgid_pred,
                        tensor(pgid_out).to(config.device),
                        args_pred,
                        tensor(args_out).type(torch.FloatTensor).to(config.device),
                        weights)
                    valid_loss.append(total_loss.item())
    finally:
        # Restore the caller's mode (train or eval) even if validation fails.
        npi.train(was_training)

    mean_loss = np.average(valid_loss) if valid_loss else float('nan')
    accuracy = correct / len(steplists) if len(steplists) else float('nan')
    return mean_loss, accuracy