Пример #1
0
 def test_reconstruct_b_unknown(self):
     """Reconstruction matches the snapshots when B is not passed to fit."""
     data = create_system_without_B()
     snapshots, control = data['snapshots'], data['u']

     model = DMDc(svd_rank=-1, opt=True)
     model.fit(snapshots, control)

     # Compare the fitted model's reconstruction to the input snapshots.
     reconstructed = model.reconstructed_data()
     np.testing.assert_array_almost_equal(reconstructed, snapshots, decimal=6)
Пример #2
0
 def test_reconstruct_b_known(self):
     """Reconstruction matches the snapshots when B is passed to fit."""
     data = create_system_with_B()
     expected = data['snapshots']

     model = DMDc(svd_rank=-1)
     model.fit(expected, data['u'], data['B'])

     # No explicit ``decimal`` here: numpy's default precision applies.
     np.testing.assert_array_almost_equal(model.reconstructed_data(), expected)
Пример #3
0
# Build the example system and print the shape of its snapshot matrix.
s = create_system(25, 10)
print(s['snapshots'].shape)

# Fit DMD with control on the snapshots and the control input ``u``.
dmdc = DMDc(svd_rank=-1)
dmdc.fit(s['snapshots'], s['u'])

# Figure 1: original snapshots (left) next to the reconstruction (right).
plt.figure(figsize=(16, 6))
plt.subplot(121)
plt.title('Original system')
plt.pcolor(s['snapshots'].real)
plt.colorbar()

plt.subplot(122)
plt.title('Reconstructed system')
plt.pcolor(dmdc.reconstructed_data().real)
plt.colorbar()
plt.show()

# Figure 2: reconstruction driven by a different control input, here the
# element-wise exponential of the original one.
new_u = np.exp(s['u'])

plt.figure(figsize=(8, 6))
plt.pcolor(dmdc.reconstructed_data(new_u).real)
plt.colorbar()
plt.show()

# Figure 3: change the time step stored in ``dmd_time`` and supply a random
# control input with one column fewer than ``dmdc.dynamics``.
dmdc.dmd_time['dt'] = .5
new_u = np.random.rand(s['u'].shape[0], dmdc.dynamics.shape[1] - 1)

plt.figure(figsize=(8, 6))
plt.pcolor(dmdc.reconstructed_data(new_u).real)
# Finish the last figure the same way as the previous two; without
# ``plt.show()`` it is never rendered when this runs as a plain script.
plt.colorbar()
plt.show()
Пример #4
0
def _require_env(env):
    """Return ``env`` unchanged, raising if no environment was created.

    Args:
        env: The environment object built by ``main``, or ``None`` when
            ``main`` was called with a falsy ``initflag``.

    Returns:
        The same ``env`` object.

    Raises:
        RuntimeError: If ``env`` is ``None``.
    """
    if env is None:
        raise RuntimeError(
            "main() was called with initflag=False, so no environment was "
            "created; the 'train' and 'test' modes need initflag=True")
    return env


def main(mode="train", sizes=(50,), initflag=True):
    """Train a ``deepq`` model on ``balancebot-v0`` or score DMDc on a rollout.

    In ``"train"`` mode a ``deepq`` model is built, trained for 100000
    timesteps and saved to ``balance_dqn.pkl``.

    In ``"test"`` mode the saved model drives the environment once per entry
    of ``sizes``. As soon as the environment's state queue holds more than
    ``size`` entries, each step fetches a window of states and inputs from
    the environment, refits (or just re-feeds) a ``DMDc`` model and records
    the norm of the reconstruction error on column index 2. The error curve
    is plotted and printed for every ``size``.

    Any other ``mode`` value does nothing.

    Args:
        mode: ``"train"`` or ``"test"``.
        sizes: Iterable of window sizes evaluated in ``"test"`` mode. The
            default is an immutable tuple so it cannot be shared and mutated
            across calls.
        initflag: When truthy, the environment is created with
            ``gym.make("balancebot-v0")``.

    Raises:
        RuntimeError: If the environment is needed but ``initflag`` was
            falsy, so none was created.
    """
    # Always bind ``env`` so a falsy ``initflag`` yields a clear error at the
    # point of use instead of an UnboundLocalError.
    env = gym.make("balancebot-v0") if initflag else None

    if mode == "train":
        env = _require_env(env)
        model = deepq(policy=LnMlpPolicy,
                      env=env,
                      double_q=True,
                      prioritized_replay=True,
                      learning_rate=1e-3,
                      buffer_size=10000,
                      verbose=0,
                      tensorboard_log="./dqn_balancebot_tensorboard")
        model.learn(total_timesteps=100000, callback=callback)
        print("Saving model to balance_dqn.pkl")
        model.save("balance_dqn.pkl")

        del model  # remove to demonstrate saving and loading

    elif mode == "test":
        model = deepq.load("balance_dqn.pkl")

        for size in sizes:
            env = _require_env(env)
            dmdc = DMDc(svd_rank=-1)
            obs = env.reset(testmode=True)
            done = False
            env.set_done(2000)
            error = []
            # Counts processed windows; a full refit happens every 50th one.
            fitflag = 0

            while not done:
                action, _states = model.predict(obs)
                # Collapse the predicted action onto two values (7 or 1).
                # NOTE(review): the meaning of these action indices is
                # defined by the environment and is not visible here.
                action = 7 if action > 4 else 1
                obs, rewards, done, info = env.step(action)

                # Wait until the queue holds more than ``size`` entries.
                if len(env.state_queue) <= size:
                    continue

                snapshots = env.get_states(size=size)
                u = env.get_inputs(size=size)

                if fitflag % 50 == 0:
                    dmdc.fit(snapshots, u)
                else:
                    # Between refits only the stored data is swapped, so the
                    # previously fitted operator is reused on the new window.
                    # NOTE(review): relies on DMDc private attributes.
                    dmdc._snapshots = snapshots
                    dmdc._controlin = u

                fitflag += 1

                # Reconstruction error restricted to column index 2.
                diff = np.linalg.norm(
                    dmdc.reconstructed_data(u)[:, 2].real -
                    snapshots[:, 2].real)
                error.append(diff)

                # Dump diagnostics when the reconstruction produced NaN.
                if np.isnan(diff):
                    print(dmdc.reconstructed_data().real[0], dmdc.eigs,
                          np.log(dmdc.eigs))

            plt.plot(error)
            print(error)