def valid_npi(questions: list, pretrained_encoder_weights: str, pretrained_npi_weights: str):
    """Evaluate a pretrained NPI model on a list of addition questions.

    Loads the ``StateEncoder`` and ``NPI`` weights from the given paths, then
    for every ``(addend, augend)`` pair runs program id 2 with no arguments on
    an ``AdditionEnv`` and compares ``env.result`` against ``addend + augend``.
    Every wrong answer is written above the progress bar.

    Args:
        questions: iterable of ``(addend, augend)`` pairs.
        pretrained_encoder_weights: path to a saved ``StateEncoder`` state dict.
        pretrained_npi_weights: path to a saved ``NPI`` state dict.

    Returns:
        ``(correct, wrong)`` counts.
    """
    # map_location lets checkpoints saved on another device (e.g. a GPU) load
    # onto whatever config.device is in use here.
    state_encoder = StateEncoder().to(config.device)
    state_encoder.load_state_dict(
        torch.load(pretrained_encoder_weights, map_location=config.device))
    npi = NPI(state_encoder, max_depth=20, max_steps=10000).to(config.device)
    npi.load_state_dict(
        torch.load(pretrained_npi_weights, map_location=config.device))
    env = AdditionEnv()
    add_program: dict = {'pgid': 2, 'args': []}
    wc: int = 0
    correct: int = 0
    npi.eval().to(config.device)
    # The postfix is passed as a dict so tqdm renders the initial counters.
    # The old '{correct} {wrong}' string was never formatted.
    # The context manager guarantees the bar is closed.
    with tqdm(questions, postfix={'correct': correct, 'wrong': wc}) as loop:
        for addend, augend in loop:
            npi.reset()
            with torch.no_grad():
                env.setup(addend, augend)
                # run npi algorithm
                npi.step(env, add_program['pgid'], add_program['args'])
            if env.result != (addend + augend):
                wc += 1
                loop.write('{:>5} + {:>5} = {:>5}'.format(addend, augend, env.result))
            else:
                correct += 1
            loop.set_postfix(correct=correct, wrong=wc)
    return correct, wc
def train_with_plot(npi: NPI, optimizer, steplists: list, epochs: int = 100, skip_correct: bool = False):
    """Train ``npi`` on recorded execution traces and plot the losses every epoch.

    Each epoch:

    - replays every trace in ``steplists`` with teacher forcing, taking one
      optimizer step per trace item;
    - calls ``validate`` on the same ``steplists``;
    - records the mean train loss and the validation loss, and rewrites
      ``{config.outdir}/loss.png``;
    - saves the model with ``npi.save()``.

    Args:
        npi: the model to train.
        optimizer: optimizer over ``npi``'s parameters.
        steplists: sequence of dicts with a ``'trace'`` key. Each trace item
            is unpacked as ``(env, (pgid, args), (pgid_out, args_out), term_out)``.
        epochs: maximum number of epochs.
        skip_correct: currently unused; kept for backward compatibility.

    Returns:
        ``True`` as soon as validation accuracy reaches 1.0, otherwise ``None``
        once all epochs have run.
    """
    _, arg_depth = config.arg_shape
    train_loss: list = []
    valid_loss: list = []
    for epoch in range(epochs):
        npi.train().to(config.device)
        losses: list = []
        loss_sum = 0.0
        loop = tqdm(steplists, ncols=100)
        loop.write('epoch: {}/{}'.format(epoch + 1, epochs))
        for step in loop:
            npi.reset()
            for env, (pgid, args), (pgid_out, args_out), term_out in step['trace']:
                optimizer.zero_grad()
                # Loss weights passed to npi_criteria, ordered as
                # [termination, program id, *one per args_out entry].
                # NOTE(review): the program-id weight is keyed on the *input*
                # pgid rather than pgid_out — confirm this is intended.
                # Arguments whose argmax is index arg_depth - 1 get ~0 weight.
                weights = [1] + [1 if 0 <= pgid < 6 else 1e-10] + [
                    1e-10 if np.argmax(arg) == (arg_depth - 1) else 1
                    for arg in args_out
                ]
                term_pred, pgid_pred, args_pred = npi(
                    tensor(env).flatten().type(torch.FloatTensor).to(config.device),
                    tensor(pgid).to(config.device),
                    tensor(args).flatten().type(torch.FloatTensor).to(config.device))
                total_loss = npi_criteria(
                    term_pred,
                    tensor(term_out).type(torch.FloatTensor).to(config.device),
                    pgid_pred,
                    tensor(pgid_out).to(config.device),
                    args_pred,
                    tensor(args_out).type(torch.FloatTensor).to(config.device),
                    weights)
                total_loss.backward()
                optimizer.step()
                loss_value = total_loss.item()
                losses.append(loss_value)
                loss_sum += loss_value
                # Running mean: O(1) per update instead of re-averaging the list.
                loop.set_postfix(loss=loss_sum / len(losses))
        loop.close()

        vloss, acc = validate(npi, steplists)
        valid_loss.append(vloss)
        train_loss.append(np.average(losses))

        # Draw on a dedicated figure and close it. Re-plotting onto pyplot's
        # global current figure every epoch stacked duplicate lines and kept
        # them all in memory for the whole run.
        epoch_axis = np.arange(1, len(train_loss) + 1)
        fig, ax = plt.subplots()
        ax.plot(epoch_axis, train_loss, 'b')
        ax.plot(epoch_axis, valid_loss, 'g')
        ax.set_ylabel('loss')
        ax.set_xlabel('epochs')
        fig.savefig(f'{config.outdir}/loss.png')
        plt.close(fig)

        npi.save()
        if acc == 1.:
            return True
def test_question(question: list, npi: NPI) -> int:
    """Run program id 2 (no arguments) on one question and return the environment's result.

    ``question`` is unpacked as ``(addend, augend)``. The NPI is reset first,
    and setup plus execution both run under ``torch.no_grad()``.
    """
    environment = AdditionEnv()
    first, second = question
    program_id, program_args = 2, []

    npi.reset()
    with torch.no_grad():
        environment.setup(first, second)
        npi.step(environment, program_id, program_args)

    return environment.result
def validate(npi: NPI, steplists: list, epochs: int = 100):
    """Measure answer accuracy and teacher-forced loss of ``npi`` on recorded traces.

    For each entry of ``steplists``:

    - the question is run end to end with ``test_question``, and the answer is
      counted correct when it equals ``np.sum(question)``;
    - the recorded trace is replayed under ``torch.no_grad()`` and the
      ``npi_criteria`` loss of every trace item is collected.

    Args:
        npi: model to evaluate; it is switched to eval mode and left there.
        steplists: sequence of dicts with ``'question'`` and ``'trace'`` keys.
        epochs: unused; kept for backward compatibility.

    Returns:
        ``(average_loss, accuracy)``. For an empty ``steplists`` this is
        ``(nan, 0.0)`` instead of raising ``ZeroDivisionError``.
    """
    if len(steplists) == 0:
        return float('nan'), 0.0

    _, arg_depth = config.arg_shape
    valid_loss: list = []
    correct = 0
    npi.eval().to(config.device)
    for step in steplists:
        question, trace = step['question'], step['trace']
        if test_question(question, npi) == np.sum(question):
            correct += 1
        # test_question leaves the model's state advanced; start the replay fresh.
        npi.reset()
        with torch.no_grad():
            for env, (pgid, args), (pgid_out, args_out), term_out in trace:
                # Same weighting as in training:
                # [termination, program id, *one per args_out entry].
                # NOTE(review): keyed on the input pgid, not pgid_out — confirm.
                weights = [1] + [1 if 0 <= pgid < 6 else 1e-10] + [
                    1e-10 if np.argmax(arg) == (arg_depth - 1) else 1
                    for arg in args_out
                ]
                term_pred, pgid_pred, args_pred = npi(
                    tensor(env).flatten().type(torch.FloatTensor).to(config.device),
                    tensor(pgid).to(config.device),
                    tensor(args).flatten().type(torch.FloatTensor).to(config.device))
                total_loss = npi_criteria(
                    term_pred,
                    tensor(term_out).type(torch.FloatTensor).to(config.device),
                    pgid_pred,
                    tensor(pgid_out).to(config.device),
                    args_pred,
                    tensor(args_out).type(torch.FloatTensor).to(config.device),
                    weights)
                valid_loss.append(total_loss.item())
    return np.average(valid_loss), correct / len(steplists)