コード例 #1
0
def main():
    """Train a DQN or DDQN agent on the gym ``Tetris-v0`` environment.

    Relies on module-level names defined elsewhere in the script:
    ``agent_type`` selects the agent variant ("dqn" or "ddqn") and
    ``number_of_actions`` is passed to ``modelutils.create_model``.
    ``headless=True`` is forwarded to ``agent.fit`` when the word
    "headless" appears in ``sys.argv``.

    Raises:
        ValueError: if ``agent_type`` is neither "dqn" nor "ddqn".
    """
    # Validate before building the environment and model. An unknown value
    # used to fall through both branches, leave ``agent`` unbound and only
    # crash later with an UnboundLocalError.
    if agent_type not in ("dqn", "ddqn"):
        raise ValueError(
            "Unknown agent_type {!r}; expected 'dqn' or 'ddqn'".format(agent_type))

    print("Creating environment...")
    environment = gym_tetris.make('Tetris-v0')

    print("Creating model...")
    model = modelutils.create_model(number_of_actions)
    model.summary()

    print("Creating agent...")
    # Hyperparameters shared by both agent variants; only the run name and
    # the DDQN-only ``model_copy_interval`` differ between them.
    agent_kwargs = dict(
        environment=environment,
        model=model,
        observation_transformation=utils.resize_and_bgr2gray,
        observation_frames=4,
        number_of_iterations=1000000,
        gamma=0.95,
        final_epsilon=0.01,
        initial_epsilon=1.0,
        replay_memory_size=2000,
        minibatch_size=32,
    )
    if agent_type == "dqn":
        agent = DQNAgent(name="tetris-dqn", **agent_kwargs)
    else:  # "ddqn" -- the only other value accepted above
        agent = DDQNAgent(name="tetris-ddqn", model_copy_interval=100, **agent_kwargs)

    agent.enable_rewards_tracking(rewards_running_means_length=10000)
    agent.enable_episodes_tracking(episodes_running_means_length=100)
    agent.enable_maxq_tracking(maxq_running_means_length=10000)
    agent.enable_model_saving(model_save_frequency=10000)
    agent.enable_plots_saving(plots_save_frequency=10000)

    print("Training ...")
    agent.fit(verbose=True, headless="headless" in sys.argv, render_states=True)
コード例 #2
0
def main():
    """Train a DQN or DDQN agent on the gym ``CartPole-v0`` environment.

    Relies on module-level names defined elsewhere in the script:
    ``agent_type`` ("dqn" or "ddqn"), ``create_model``,
    ``observation_transformation`` and ``reward_transformation``.

    Unlike the other training scripts, rendering is opt-in here:
    ``headless`` is True unless the word "render" appears in ``sys.argv``.

    Raises:
        ValueError: if ``agent_type`` is neither "dqn" nor "ddqn".
    """
    # Validate before building anything. An unknown value used to leave
    # ``agent`` unbound and crash later with an UnboundLocalError.
    if agent_type not in ("dqn", "ddqn"):
        raise ValueError(
            "Unknown agent_type {!r}; expected 'dqn' or 'ddqn'".format(agent_type))

    print("Creating model...")
    model = create_model()
    model.summary()

    print("Creating environment...")
    environment = gym.make("CartPole-v0")
    # Raise the per-episode step cap. ``_max_episode_steps`` is a private
    # attribute -- presumably of the TimeLimit wrapper that gym.make applies;
    # NOTE(review): confirm it still exists in the installed gym version.
    environment._max_episode_steps = 500

    print("Creating agent...")
    # Hyperparameters shared by both agent variants; only the run name and
    # the DDQN-only ``model_copy_interval`` differ between them.
    agent_kwargs = dict(
        model=model,
        environment=environment,
        observation_frames=1,
        observation_transformation=observation_transformation,
        reward_transformation=reward_transformation,
        gamma=0.95,
        final_epsilon=0.01,
        initial_epsilon=1.0,
        number_of_iterations=1000000,
        replay_memory_size=2000,
        minibatch_size=32,
    )
    if agent_type == "dqn":
        agent = DQNAgent(name="cartpole-dqn", **agent_kwargs)
    else:  # "ddqn" -- the only other value accepted above
        agent = DDQNAgent(name="cartpole-ddqn", model_copy_interval=100, **agent_kwargs)

    agent.enable_rewards_tracking(rewards_running_means_length=10000)
    agent.enable_episodes_tracking(episodes_running_means_length=10000)
    agent.enable_maxq_tracking(maxq_running_means_length=10000)
    agent.enable_model_saving(model_save_frequency=100000)
    agent.enable_tensorboard_for_tracking()

    print("Training ...")
    agent.fit(verbose=True, headless="render" not in sys.argv)
コード例 #3
0
ファイル: train-doom.py プロジェクト: dirkh24/dqn_playground
def main():
    """Train a DQN or DDQN agent on the ViZDoom ``scenarios/basic.cfg`` scenario.

    Relies on module-level names defined elsewhere in the script:
    ``agent_type`` ("dqn" or "ddqn"), ``modelutils``, ``DoomGame``,
    ``ScreenFormat`` and ``train``. The game window is shown unless the
    word "headless" appears in ``sys.argv``; ``verbose=True`` is passed to
    ``train`` when "verbose" appears in ``sys.argv``.

    Raises:
        ValueError: if ``agent_type`` is neither "dqn" nor "ddqn".
    """
    # Validate before building anything. An unknown value used to leave
    # ``agent`` unbound and crash later with an UnboundLocalError.
    if agent_type not in ("dqn", "ddqn"):
        raise ValueError(
            "Unknown agent_type {!r}; expected 'dqn' or 'ddqn'".format(agent_type))

    # Passed to both create_model and the agent so the two stay in sync.
    action_count = 4

    print("Creating model...")
    model = modelutils.create_model(number_of_actions=action_count)
    model.summary()

    print("Creating agent...")
    # Hyperparameters shared by both agent variants; only the run name and
    # the DDQN-only ``model_copy_interval`` differ between them.
    agent_kwargs = dict(
        model=model,
        number_of_actions=action_count,
        gamma=0.99,
        final_epsilon=0.0001,
        initial_epsilon=0.1,
        number_of_iterations=200000,
        replay_memory_size=10000,
        minibatch_size=32,
    )
    if agent_type == "dqn":
        agent = DQNAgent(name="doom-dqn", **agent_kwargs)
    else:  # "ddqn" -- the only other value accepted above
        agent = DDQNAgent(name="doom-ddqn", model_copy_interval=100, **agent_kwargs)

    agent.enable_rewards_tracking(rewards_running_means_length=1000)
    agent.enable_episodes_tracking(episodes_running_means_length=1000)
    agent.enable_maxq_tracking(maxq_running_means_length=1000)
    agent.enable_model_saving(model_save_frequency=10000)
    agent.enable_plots_saving(plots_save_frequency=10000)

    print("Creating game...")
    # Create an instance of the Doom game with grayscale frames.
    environment = DoomGame()
    environment.load_config("scenarios/basic.cfg")
    environment.set_screen_format(ScreenFormat.GRAY8)
    environment.set_window_visible("headless" not in sys.argv)
    environment.init()

    print("Training ...")
    train(agent, environment, verbose="verbose" in sys.argv)
コード例 #4
0
def main():
    """Train a DQN or DDQN agent on gym ``SuperMarioBros-v0``.

    Relies on module-level names defined elsewhere in the script:
    ``agent_type`` ("dqn" or "ddqn"), ``number_of_actions``, ``actions``,
    ``modelutils``, ``BinarySpaceToDiscreteSpaceEnv`` and ``train``.
    ``verbose`` / ``headless`` are passed to ``train`` as True when the
    corresponding word appears in ``sys.argv``.

    Raises:
        ValueError: if ``agent_type`` is neither "dqn" nor "ddqn".
    """
    # Validate before building anything. An unknown value used to leave
    # ``agent`` unbound and crash later with an UnboundLocalError.
    if agent_type not in ("dqn", "ddqn"):
        raise ValueError(
            "Unknown agent_type {!r}; expected 'dqn' or 'ddqn'".format(agent_type))

    print("Creating model...")
    model = modelutils.create_model(number_of_actions)
    model.summary()

    print("Creating agent...")
    # Hyperparameters shared by both agent variants; only the run name and
    # the DDQN-only ``model_copy_interval`` differ between them.
    agent_kwargs = dict(
        model=model,
        number_of_actions=number_of_actions,
        gamma=0.95,
        final_epsilon=0.01,
        initial_epsilon=1.0,
        number_of_iterations=1000000,
        replay_memory_size=2000,
        minibatch_size=32,
    )
    if agent_type == "dqn":
        agent = DQNAgent(name="supermario-dqn", **agent_kwargs)
    else:  # "ddqn" -- the only other value accepted above
        agent = DDQNAgent(name="supermario-ddqn", model_copy_interval=100, **agent_kwargs)

    agent.enable_rewards_tracking(rewards_running_means_length=10000)
    agent.enable_episodes_tracking(episodes_running_means_length=100)
    agent.enable_maxq_tracking(maxq_running_means_length=10000)
    agent.enable_model_saving(model_save_frequency=10000)
    agent.enable_plots_saving(plots_save_frequency=10000)

    print("Creating game...")
    environment = gym_super_mario_bros.make("SuperMarioBros-v0")
    # Wrap the environment with the module-level ``actions`` list; presumably
    # this maps it to a discrete action space -- confirm against nes_py docs.
    environment = BinarySpaceToDiscreteSpaceEnv(environment, actions)

    print("Training ...")
    train(agent,
          environment,
          verbose="verbose" in sys.argv,
          headless="headless" in sys.argv)