Example #1
import torch
from tqdm import tqdm

# StateEncoder, NPI, AdditionEnv and config are assumed to be importable
# from the surrounding project package.
def valid_npi(questions: list, pretrained_encoder_weights: str,
              pretrained_npi_weights: str):
    state_encoder = StateEncoder().to(config.device)
    state_encoder.load_state_dict(
        torch.load(pretrained_encoder_weights, map_location=config.device))
    npi = NPI(state_encoder, max_depth=20, max_steps=10000).to(config.device)
    npi.load_state_dict(
        torch.load(pretrained_npi_weights, map_location=config.device))
    env = AdditionEnv()
    add_program: dict = {'pgid': 2, 'args': []}
    wc: int = 0
    correct: int = 0
    npi.eval()  # already on config.device; just switch to inference mode
    # tqdm takes a dict postfix; the original passed an unformatted string
    loop = tqdm(questions, postfix={'correct': correct, 'wrong': wc})
    for addend, augend in loop:
        npi.reset()
        with torch.no_grad():
            env.setup(addend, augend)
            # run npi algorithm
            npi.step(env, add_program['pgid'], add_program['args'])
        if env.result != (addend + augend):
            wc += 1
            loop.write('{:>5} + {:>5} = {:>5}'.format(addend, augend,
                                                      env.result))
        else:
            correct += 1
        loop.set_postfix(correct=correct, wrong=wc)
    return correct, wc
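A hypothetical driver for this validator; the question count, digit range and
weight-file names below are illustrative, not taken from the project.

import random

questions = [(random.randint(0, 9999), random.randint(0, 9999))
             for _ in range(100)]
correct, wrong = valid_npi(questions, 'state_encoder.pth', 'npi.pth')
print('accuracy: {:.2%}'.format(correct / len(questions)))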
Example #2
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import tensor
from tqdm import tqdm

# NPI, npi_criteria, validate and config are assumed project-local;
# skip_correct is accepted but never used in this snippet.
def train_with_plot(npi: NPI,
                    optimizer,
                    steplists: list,
                    epochs: int = 100,
                    skip_correct: bool = False):
    _, arg_depth = config.arg_shape
    train_loss: list = []
    valid_loss: list = []
    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=1e-1, last_epoch=-1)
    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=1e-1, patience=2)
    for epoch in range(epochs):
        npi.train().to(config.device)
        # collect per-step losses for this epoch
        losses: list = []
        # np.random.shuffle(steplists)
        loop = tqdm(steplists, ncols=100)
        loop.write('epoch: {}/{}'.format(epoch + 1, epochs))
        for idx, step in enumerate(loop):
            question, trace = step['question'], step['trace']
            npi.reset()
            for env, (pgid, args), (pgid_out, args_out), term_out in trace:
                optimizer.zero_grad()
                # down-weight the program-id loss when the current pgid is out
                # of range, and each argument loss whose target is the "empty"
                # (last-index) one-hot
                weights = [1] + [1 if 0 <= pgid < 6 else 1e-10] + [
                    1e-10 if np.argmax(arg) == (arg_depth - 1) else 1
                    for arg in args_out
                ]
                # forward pass: predict termination, next program id and args
                term_pred, pgid_pred, args_pred = npi(
                    tensor(env).flatten().type(torch.FloatTensor).to(
                        config.device),
                    tensor(pgid).to(config.device),
                    tensor(args).flatten().type(torch.FloatTensor).to(
                        config.device))
                total_loss = npi_criteria(
                    term_pred,
                    tensor(term_out).type(torch.FloatTensor).to(config.device),
                    pgid_pred,
                    tensor(pgid_out).to(config.device), args_pred,
                    tensor(args_out).type(torch.FloatTensor).to(config.device),
                    weights)
                total_loss.backward()
                optimizer.step()
                losses.append(total_loss.item())
            loop.set_postfix(loss=np.average(losses))
        loop.close()
        # NOTE: validation runs on the same steplists used for training
        vloss, acc = validate(npi, steplists)
        valid_loss.append(vloss)
        train_loss.append(np.average(losses))
        xlabel = np.arange(1, len(train_loss) + 1)
        plt.clf()  # fresh figure; otherwise curves accumulate every epoch
        plt.plot(xlabel, train_loss, 'b', label='train')
        plt.plot(xlabel, valid_loss, 'g', label='valid')
        plt.legend()
        plt.ylabel('loss')
        plt.xlabel('epochs')
        plt.savefig(f'{config.outdir}/loss.png')
        # scheduler.step()
        npi.save()
        if acc == 1.:
            return True
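The loss function npi_criteria is not shown on this page. Judging from the
call sites, it takes the three head outputs, their targets, and a weight list
ordered [termination, program id, one per argument]. A minimal sketch under
those assumptions:

import torch.nn.functional as F

def npi_criteria(term_pred, term_out, pgid_pred, pgid_out,
                 args_pred, args_out, weights):
    # Termination head: binary cross-entropy on the stop probability.
    term_loss = F.binary_cross_entropy(term_pred, term_out)
    # Program head: cross-entropy over program-id logits.
    pgid_loss = F.cross_entropy(pgid_pred.view(1, -1), pgid_out.view(1))
    # Argument heads: one loss per argument slot (one-hot targets assumed).
    arg_losses = [F.binary_cross_entropy(p, t)
                  for p, t in zip(args_pred, args_out)]
    losses = [term_loss, pgid_loss] + arg_losses
    return sum(w * l for w, l in zip(weights, losses))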
Example #3
import torch

# NPI and AdditionEnv are assumed project-local.
def test_question(question: list, npi: NPI) -> int:
    env = AdditionEnv()
    addend, augend = question
    add_program: dict = {'pgid': 2, 'args': []}
    npi.reset()
    with torch.no_grad():
        env.setup(addend, augend)
        # run npi algorithm
        npi.step(env, add_program['pgid'], add_program['args'])
        # read the computed sum off the environment
        return env.result
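Illustrative call; the checkpoint path is hypothetical and a trained model is
assumed.

npi = NPI(StateEncoder(), max_depth=20, max_steps=10000).to(config.device)
npi.load_state_dict(torch.load('npi.pth', map_location=config.device))
npi.eval()
print(test_question([123, 456], npi))  # a trained model should print 579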
Example #4
import numpy as np
import torch
from torch import tensor

# NPI, test_question, npi_criteria and config are assumed project-local.
def validate(npi: NPI, steplists: list):
    _, arg_depth = config.arg_shape
    valid_loss: list = []
    correct = wc = 0
    npi.eval().to(config.device)
    for step in steplists:
        question, trace = step['question'], step['trace']
        res = test_question(question, npi)
        if res == np.sum(question):
            correct += 1
        else:
            wc += 1
        npi.reset()
        with torch.no_grad():
            for env, (pgid, args), (pgid_out, args_out), term_out in trace:
                # same per-head loss weighting as in training
                weights = [1] + [1 if 0 <= pgid < 6 else 1e-10] + [
                    1e-10 if np.argmax(arg) == (arg_depth - 1) else 1
                    for arg in args_out
                ]
                # forward pass: predict termination, next program id and args
                term_pred, pgid_pred, args_pred = npi(
                    tensor(env).flatten().type(torch.FloatTensor).to(
                        config.device),
                    tensor(pgid).to(config.device),
                    tensor(args).flatten().type(torch.FloatTensor).to(
                        config.device))
                total_loss = npi_criteria(
                    term_pred,
                    tensor(term_out).type(torch.FloatTensor).to(config.device),
                    pgid_pred,
                    tensor(pgid_out).to(config.device), args_pred,
                    tensor(args_out).type(torch.FloatTensor).to(config.device),
                    weights)
                valid_loss.append(total_loss.item())
    return np.average(valid_loss), correct / len(steplists)
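The steplists structure these functions consume is not defined on this page.
Inferred from the unpacking above, each entry looks roughly like the sketch
below; the observation shape and arg_shape values are pure assumptions.

import numpy as np

arg_num, arg_depth = 3, 11                 # assumed config.arg_shape
empty = np.eye(arg_depth)[arg_depth - 1]   # "no argument" one-hot

step = {
    'question': [12, 34],  # addend, augend
    'trace': [
        # (env observation, (pgid, args) input, (pgid, args) target, term)
        (np.zeros((4, 10)),                  # scratch-pad observation
         (2, [empty] * arg_num),             # current program id + arguments
         (4, [np.eye(arg_depth)[0]] + [empty] * (arg_num - 1)),
         0),                                 # 0 = do not terminate yet
    ],
}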