Example #1
0
 def test_reconstruct_b_unknown(self):
     """With B not passed to fit (opt=True), the reconstruction matches the snapshots to 6 decimals."""
     data = create_system_without_B()
     model = DMDc(svd_rank=-1, opt=True)
     model.fit(data['snapshots'], data['u'])
     reconstruction = model.reconstructed_data()
     np.testing.assert_array_almost_equal(
         reconstruction, data['snapshots'], decimal=6)
Example #2
0
 def test_atilde_b_unknown(self):
     """With B unknown, atilde approximates (1 decimal) A projected onto the fitted basis."""
     data = create_system_without_B()
     model = DMDc(svd_rank=-1, opt=True)
     model.fit(data['snapshots'], data['u'])
     # Projection basis^H @ A @ basis, using the conjugate transpose of the basis.
     basis_h = model.basis.T.conj()
     projected_a = basis_h.dot(data['A']).dot(model.basis)
     np.testing.assert_array_almost_equal(model.atilde, projected_a, decimal=1)
Example #3
0
 def test_B_b_known(self):
     """When the true B is passed to fit, DMDc exposes (approximately) that same B."""
     data = create_system_with_B()
     model = DMDc(svd_rank=-1)
     known_b = data['B']
     model.fit(data['snapshots'], data['u'], known_b)
     np.testing.assert_array_almost_equal(model.B, data['B'])
Example #4
0
 def test_modes_b_unknown(self):
     """A fit truncated to svd_rank=3 produces a modes matrix with three columns."""
     data = create_system_without_B()
     model = DMDc(svd_rank=3, opt=False, svd_rank_omega=4)
     model.fit(data['snapshots'], data['u'])
     n_modes = model.modes.shape[1]
     self.assertEqual(n_modes, 3)
Example #5
0
 def test_eigs_b_known(self):
     """With B known, the fitted eigenvalues are [0.1, 1.5]."""
     data = create_system_with_B()
     model = DMDc(svd_rank=-1)
     model.fit(data['snapshots'], data['u'], data['B'])
     expected_eigs = np.array([0.1, 1.5])
     np.testing.assert_array_almost_equal(model.eigs, expected_eigs)
Example #6
0
def create_system(n, m, seed=None):
    """Simulate a random linear control system ``x[k+1] = A x[k] + B u[k]``.

    ``A`` is the full ``n x n`` Helmert matrix (``scipy.linalg.helmert(n, True)``)
    and ``B`` is an ``n x n`` matrix with entries drawn uniformly from
    ``[-0.5, 0.5)``. Starting from ``x0 = [0.25] * n``, ``m - 1`` random inputs
    (also uniform in ``[-0.5, 0.5)``) are applied, giving ``m`` snapshots.

    Parameters
    ----------
    n : int
        State dimension. The input dimension is the same, because ``B`` is square.
    m : int
        Number of snapshots to produce, including the initial state.
    seed : int or None, optional
        Seed for a private ``np.random.RandomState``, which makes the system
        reproducible without touching NumPy's global random state. ``None``
        (the default) draws from the global generator.

    Returns
    -------
    dict
        ``'snapshots'`` with shape (n, m), ``'u'`` with shape (n, m - 1),
        ``'B'`` with shape (n, n) and ``'A'`` with shape (n, n).
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    A = scipy.linalg.helmert(n, True)
    # Draw B before u so the random stream is consumed in a fixed order.
    B = rng.rand(n, n) - .5
    u = rng.rand(n, m - 1) - .5

    # Column k holds the state at step k.
    snapshots = np.empty((n, m))
    snapshots[:, 0] = 0.25
    for k in range(m - 1):
        snapshots[:, k + 1] = A.dot(snapshots[:, k]) + B.dot(u[:, k])
    return {'snapshots': snapshots, 'u': u, 'B': B, 'A': A}


# Demo: simulate a 25-state system with 10 snapshots, then fit DMDc without
# passing B to fit() (only the snapshots and control inputs are given).
s = create_system(25, 10)
print(s['snapshots'].shape)  # snapshots are stacked column-wise as (n, m)

dmdc = DMDc(svd_rank=-1)
dmdc.fit(s['snapshots'], s['u'])

# Side-by-side pcolor plots of the real part of the original snapshot
# matrix (left) and of DMDc's reconstruction (right).
plt.figure(figsize=(16, 6))
plt.subplot(121)
plt.title('Original system')
plt.pcolor(s['snapshots'].real)
plt.colorbar()

plt.subplot(122)
plt.title('Reconstructed system')
plt.pcolor(dmdc.reconstructed_data().real)
plt.colorbar()
plt.show()

# NOTE(review): new_u is not used in the visible code; presumably it was meant
# to be passed to dmdc.reconstructed_data(new_u) to reconstruct the system
# under a different control input. The example looks truncated; confirm.
new_u = np.exp(s['u'])
Example #7
0
 def test_btilde_b_unknown(self):
     """With B not given to fit, the estimated btilde matches the reference column vector."""
     model = DMDc(svd_rank=-1)
     model.fit(snapshots, control)
     reference_btilde = np.array([[-0.05836184], [0.31070992]])
     np.testing.assert_array_almost_equal(model.btilde, reference_btilde)
Example #8
0
 def test_btilde_b_known(self):
     """When b is passed to fit, btilde equals that b (to default precision)."""
     model = DMDc(svd_rank=-1)
     model.fit(snapshots, control, b)
     fitted_btilde = model.btilde
     np.testing.assert_array_almost_equal(fitted_btilde, b)
Example #9
0
 def test_reconstruct_b_unknown(self):
     """With B unknown, the reconstruction reproduces every snapshot after the first."""
     model = DMDc(svd_rank=-1)
     model.fit(snapshots, control)
     # NOTE(review): reconstructed_data is read as an attribute here (not called),
     # unlike other snippets that call reconstructed_data(). Presumably this targets
     # an older DMDc API where it was a property; verify against the installed version.
     reconstruction = model.reconstructed_data
     np.testing.assert_array_almost_equal(reconstruction, snapshots[:, 1:])
Example #10
0
 def test_atilde_b_known(self):
     """With B known, atilde is the diagonal matrix [[1.5, 0], [0, 0.1]]."""
     model = DMDc(svd_rank=-1)
     model.fit(snapshots, control, b)
     expected_atilde = np.array([[1.5, 0],
                                 [0, 0.1]])
     np.testing.assert_array_almost_equal(model.atilde, expected_atilde)
Example #11
0
def _train_balance_dqn(env):
    """Train a DQN (double-Q, prioritized replay) on ``env`` and save it to balance_dqn.pkl."""
    model = deepq(policy=LnMlpPolicy,
                  env=env,
                  double_q=True,
                  prioritized_replay=True,
                  learning_rate=1e-3,
                  buffer_size=10000,
                  verbose=0,
                  tensorboard_log="./dqn_balancebot_tensorboard")
    # ``callback`` is resolved at module level; it is not defined in this function.
    model.learn(total_timesteps=100000, callback=callback)
    print("Saving model to balance_dqn.pkl")
    model.save("balance_dqn.pkl")


def _evaluate_dmdc_window(env, model, size, refit_every=50):
    """Run one test episode and return the per-step DMDc reconstruction errors.

    Once the environment's state queue holds more than ``size`` entries, the
    latest ``size``-long window of states/inputs is taken at every step. The
    error is the norm of the difference between the reconstructed and actual
    snapshots at column index 2.

    Parameters
    ----------
    env : environment
        Must provide ``reset(testmode=...)``, ``set_done``, ``step``,
        ``state_queue``, ``get_states`` and ``get_inputs``.
    model : policy
        Object with a ``predict(obs)`` method returning ``(action, state)``.
    size : int
        Window length passed to ``get_states`` / ``get_inputs``.
    refit_every : int
        DMDc is refit on every ``refit_every``-th window, starting with the first.

    Returns
    -------
    list of float
        One error value per evaluated window.
    """
    dmdc = DMDc(svd_rank=-1)
    obs = env.reset(testmode=True)
    done = False
    env.set_done(2000)
    errors = []
    windows_seen = 0

    while not done:
        action, _states = model.predict(obs)
        # Map the policy's output onto the two actions used here (7 or 1).
        action = 7 if action > 4 else 1
        obs, _rewards, done, _info = env.step(action)

        if len(env.state_queue) <= size:
            continue  # not enough history yet for a full window

        snapshots = env.get_states(size=size)
        u = env.get_inputs(size=size)

        if windows_seen % refit_every == 0:
            dmdc.fit(snapshots, u)
        else:
            # Swap in the new data window without refitting. This writes private
            # DMDc attributes; presumably the previously fitted operators are then
            # reused by reconstructed_data. Verify against the DMDc implementation.
            dmdc._snapshots = snapshots
            dmdc._controlin = u
        windows_seen += 1

        diff = np.linalg.norm(
            dmdc.reconstructed_data(u)[:, 2].real -
            snapshots[:, 2].real)
        errors.append(diff)

        if np.isnan(diff):
            # Diagnostic dump when the reconstruction blows up.
            print(dmdc.reconstructed_data().real[0], dmdc.eigs,
                  np.log(dmdc.eigs))

    return errors


def main(mode="train", sizes=(50,), initflag=True, env=None, refit_every=50):
    """Train a balance-bot DQN, or evaluate DMDc reconstruction while it runs.

    Parameters
    ----------
    mode : str
        ``"train"`` trains a policy and saves it to ``balance_dqn.pkl``.
        ``"test"`` loads that file and, for each window size in ``sizes``, runs
        one episode, then plots and prints the DMDc reconstruction errors.
        Any other value performs no training or evaluation.
    sizes : iterable of int
        Window lengths used in test mode. The default is a tuple so it cannot be
        shared and mutated across calls.
    initflag : bool
        If true, create a fresh ``"balancebot-v0"`` environment with ``gym.make``.
    env : optional
        Environment to use when ``initflag`` is false.
    refit_every : int
        In test mode, DMDc is refit every ``refit_every`` windows (default 50).

    Raises
    ------
    ValueError
        If ``mode`` is ``"train"`` or ``"test"`` and no environment is
        available, or if ``refit_every`` is not positive.
    """
    if initflag:
        env = gym.make("balancebot-v0")
    # Without this check, initflag=False previously hit an UnboundLocalError on ``env``.
    if mode in ("train", "test") and env is None:
        raise ValueError("main() needs an environment: pass initflag=True or env=...")
    if refit_every < 1:
        raise ValueError("refit_every must be a positive integer")

    if mode == "train":
        _train_balance_dqn(env)

    if mode == "test":
        model = deepq.load("balance_dqn.pkl")

        for size in sizes:
            error = _evaluate_dmdc_window(env, model, size, refit_every)
            plt.plot(error)
            print(error)