示例#1
0
def compare_V(critic, A, B, Q, R, K, T, gamma, alpha, low=-1, high=1,
              num_points=100):
    """Plot the scaled critic output against ``control.trueloss``.

    The critic is evaluated on ``num_points`` evenly spaced states in
    ``[low, high]`` (as a ``(num_points, 1)`` tensor); its output is
    squeezed, converted to NumPy and multiplied by ``alpha``.  The
    reference curve comes from ``control.trueloss`` evaluated on the same
    states.  Both curves are drawn on one axes and the figure is shown
    with ``plt.show()``.

    Args:
        critic: Callable applied to the ``(num_points, 1)`` state tensor;
            must return a torch tensor.
        A, B, Q, R, K, T, gamma: Forwarded unchanged to
            ``control.trueloss``.
        alpha: Scalar factor applied to the critic output before plotting.
        low: Lower end of the plotted state interval.
        high: Upper end of the plotted state interval.
        num_points: Number of sample states (default 100, the value the
            original code hard-coded).

    Returns:
        None.
    """
    _, ax = plt.subplots()
    colors = ['#B53737', '#2D328F']  # red, blue
    label_fontsize = 18

    # ``steps`` must be passed explicitly: current PyTorch no longer
    # provides a default for it.
    states = torch.linspace(low, high, steps=num_points).reshape(-1, 1)
    states_np = states.numpy()

    # Evaluation only, so skip building an autograd graph.
    with torch.no_grad():
        values = alpha * critic(states).squeeze().detach().numpy()

    true_values = control.trueloss(A, B, Q, R, K, states_np, T,
                                   gamma).reshape(states.shape[0])

    ax.plot(states_np,
            values,
            color=colors[0],
            label='Approx. Loss Function')
    ax.plot(states_np,
            true_values,
            color=colors[1],
            label='Real Loss Function')

    ax.set_xlabel('x', fontsize=label_fontsize)
    ax.set_ylabel('y', fontsize=label_fontsize)
    ax.legend()
    ax.grid(True)

    plt.show()
示例#2
0
def live_train(K, low=-1, high=1):
    """Train ``model`` online and periodically redraw its fit.

    For each of ``NUM_TRIALS`` trials a random 1x1 start state is drawn
    and ``T`` steps are simulated with the feedback ``u = -K x``.  At each
    step the target ``r + ALPHA * GAMMA * model(x_next)`` is formed, the
    model is run forward on ``x``, and ``net_backward`` / ``update_wb``
    are called with a learning rate of 0.001.  About five times per trial
    the axes are cleared and redrawn with ``ALPHA * model(x)`` (red) and
    ``control.trueloss`` (black) over 100 points in ``[low, high]``.

    Relies on names not defined in this block: ``NUM_TRIALS``, ``T``,
    ``ALPHA``, ``GAMMA``, ``A``, ``B``, ``Q``, ``R``, ``model``,
    ``control``, ``np`` and ``plt``.

    Args:
        K: Feedback gain used in ``u = -K x`` and passed to
            ``control.trueloss``.
        low: Lower end of the plotted state interval.
        high: Upper end of the plotted state interval.

    Returns:
        None.
    """
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.set_xlim(low, high)
    xtest = np.linspace(low, high, 100)
    xs = xtest.reshape(xtest.size, 1)

    lr = 0.001
    # Integer refresh period: the float form ``(i + 1) % (T / 5) == 0``
    # misfires whenever T is not a multiple of 5 (e.g. 12 % 2.4 != 0).
    refresh_every = max(1, T // 5)

    for _ in range(NUM_TRIALS):
        x = np.random.randn(1).reshape(1, 1)
        for i in range(T):
            if (i + 1) % refresh_every == 0:
                approx_curve = np.array([
                    ALPHA * model(np.array(x1).reshape(1, 1)).item()
                    for x1 in xtest
                ])
                true_curve = control.trueloss(A, B, Q, R, K, xs, T,
                                              GAMMA).reshape(xtest.size)

                ax.clear()
                ax.plot(xtest, approx_curve, 'r-')
                ax.plot(xtest, true_curve, 'k-')

                # ``ax.clear()`` resets limits, grid and labels, so they
                # are all re-applied on every refresh.
                ax.set_xlim(low, high)
                ax.grid(True)
                ax.set_xlabel('x', fontsize=18)
                ax.set_ylabel('y', fontsize=18)
                plt.pause(0.05)

            u = -np.matmul(K, x)
            r = np.matmul(x, np.matmul(Q, x)) + np.matmul(u, np.matmul(R, u))
            x_next = np.matmul(A, x) + np.matmul(B, u)

            # Evaluate the bootstrap target first, then run the forward
            # pass on x last, immediately before net_backward (same order
            # as the original).
            target = r + ALPHA * GAMMA * model(x_next)
            prediction = model.net_forward(x)

            model.net_backward(target, prediction, ALPHA)
            model.update_wb(lr)

            x = x_next

    plt.show()