#     dqn.load_weights(checkpoint_weights_filename)
    # elif os.path.isfile(weights_filename):
    #     print("Loading previous weights...")
    #     dqn.load_weights(weights_filename)
    dqn.fit(env, callbacks=callbacks, nb_steps=20000000, log_interval=10000)

    # After training is done, we save the final weights one more time.
    dqn.save_weights(weights_filename, overwrite=True)

    # Finally, evaluate our algorithm for 10 episodes.
    dqn.test(env, nb_episodes=10, visualize=False)
elif args.mode == 'test':
    weights_filename = 'wts/dqn_Breakout-v0_weights_12000000_phyran.h5f'  # default checkpoint; overridden by args.weights below
    if args.weights:
        weights_filename = args.weights
    print(env.unwrapped.get_action_meanings())
    np.random.seed(None)
    env.seed(None)
    dqn.load_weights(weights_filename)
    dqn.training = False
    dqn.test_policy = EpsilonPhysicsPolicy(
        eps_phy=0.01, eps_ran=0.00
    )  # use a small epsilon in the test policy so the agent does not get stuck (see the policy sketch after this example)
    env = gym.wrappers.Monitor(env,
                               "records/",
                               video_callable=lambda episode_id: True,
                               force=True)
    dqn.test(env, nb_episodes=100, visualize=False)
    env.close()
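
EpsilonPhysicsPolicy is not part of keras-rl and its definition is not included in this example. The sketch below shows what such a mixed test policy might look like against keras-rl's Policy interface; the eps_ran/eps_phy semantics (a rare random action vs. a rare scripted "physics" action such as FIRE in Breakout) and the physics_action parameter are assumptions, not the author's actual implementation.

import numpy as np
from rl.policy import Policy

class EpsilonPhysicsPolicy(Policy):
    """Greedy on Q-values, except: with probability eps_ran pick a uniformly
    random action, and with probability eps_phy pick a fixed scripted action
    (both interpretations are assumptions based on the parameter names)."""

    def __init__(self, eps_phy=0.01, eps_ran=0.0, physics_action=1):
        super(EpsilonPhysicsPolicy, self).__init__()
        self.eps_phy = eps_phy
        self.eps_ran = eps_ran
        self.physics_action = physics_action  # hypothetical scripted action (e.g. FIRE = 1 in Breakout)

    def select_action(self, q_values):
        nb_actions = q_values.shape[0]
        r = np.random.uniform()
        if r < self.eps_ran:
            return np.random.randint(0, nb_actions)   # rare random action
        if r < self.eps_ran + self.eps_phy:
            return self.physics_action                # rare scripted action to avoid getting stuck
        return int(np.argmax(q_values))               # otherwise act greedily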
Example #2
## Init RL agent
agent = DQNAgent(model=model, nb_actions=nb_actions,
    memory=memory, nb_steps_warmup=1000,
    target_model_update=1e-2, policy=policy,
    processor=MultiInputProcessor(2),
    # enable_dueling_network=True, dueling_type='avg'
)
agent.compile(Adam(lr=1e-3), metrics=['mae'])

## Comment out this line if you want to start learning from scratch
agent.load_weights('{p}/dqn_{fn}_weights.h5f'.format(p=PATH, fn=ENV_NAME))
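
## A guarded variant (mirroring the commented-out resume logic in the previous
## example) could look like this; a sketch, assuming the same PATH/ENV_NAME layout:
# import os
# weights_file = '{p}/dqn_{fn}_weights.h5f'.format(p=PATH, fn=ENV_NAME)
# if os.path.isfile(weights_file):
#     print("Loading previous weights...")
#     agent.load_weights(weights_file)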

## Train or evaluate
if TRAIN:
    agent.training = True

observation = market.reset()

while True:
    try:
        # TODO add callbacks?

        ## Agent selects an action
        # (candles=9(mb=>(2,4)?), tickers=4, trades=2)
        # TODO: actions for a multi-symbol market
        action = agent.forward(observation)

        ## Execute action
        observation, reward, done, info = market.step([action])